In [None]:
# default_exp data

In [None]:
#all_slow

# Data
> Classes and functions for managing data

In [None]:
#hide
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
#export
from lemonpie.basics import *
from lemonpie.preprocessing.transform import *
from fastai.imports import *
import copy, glob

In [None]:
#hide
from nbdev.showdoc import *

## Split

- Splitting is already done in the raw data before vocab creation.
- The following class is to load and manage the pre-processed splits together.

In [None]:
# export
class EHRDataSplits:
    """Container for the preprocessed `PatientList` splits (train, valid, test).

    `self.splits` maps each split name to a list of `PatientList`s, one per
    modality type found on disk; `self.modality_types` maps each split name to
    the matching list of modality-type names (same order).
    """

    def __init__(self, path, age_start, age_range, start_is_date, age_in_months):
        loaded = self._load_splits(
            path, age_start, age_range, start_is_date, age_in_months
        )
        self.splits, self.modality_types = loaded

    def _load_splits(self, path, age_start, age_range, start_is_date, age_in_months):
        """Load splits of preprocessed `PatientList`s from persistent store using path."""
        splits, modality_types = {}, {}
        for split in ("train", "valid", "test"):
            # NOTE(review): 999 looks like a dummy modality type - only the
            # `.parent` of the returned directory is used, to list the
            # per-modality-type directories that actually exist for this split.
            pckl_dir = get_pckl_dir(
                path, split, 999, age_start, age_range, age_in_months
            )
            # The modality type is the text after the last "_" in each entry name.
            found_types = [
                entry.name.split("_")[-1] for entry in pckl_dir.parent.iterdir()
            ]
            modality_types[split] = found_types

            load_kwargs = dict(
                path=path,
                split=split,
                age_start=age_start,
                age_range=age_range,
                start_is_date=start_is_date,
                age_in_months=age_in_months,
            )
            ptlists = []
            for m_type in found_types:
                ptlists.append(PatientList.load(modality_type=m_type, **load_kwargs))
            splits[split] = ptlists

        return splits, modality_types

    def get_splits_modtypes(self):
        """Return splits and modality types."""
        return self.splits, self.modality_types

    def get_lengths(self):
        """Return a dataframe with lengths (# of patients) of the splits (train, valid, test) and total."""
        # Relies on `self.splits` holding exactly three entries in train/valid/test order.
        train, valid, test = self.splits.values()
        lengths = [
            sum([len(ptlist) for ptlist in split]) for split in (train, valid, test)
        ]
        lengths.append(sum(lengths))
        return pd.DataFrame(
            lengths, index=["train", "valid", "test", "total"], columns=["lengths"]
        )

    def get_label_counts(self, labels):
        """Get prevalence counts of labels in each split
        returns a dataframe with counts for each split and total count."""
        train, valid, test = self.splits.values()

        # Flatten every split: one list of patients across all modality types.
        flattened = [
            [patient for ptlist in split for patient in ptlist]
            for split in (train, valid, test)
        ]

        rows = []
        for label in labels:
            per_split = [
                [patient.conditions[label] == 1 for patient in patients].count(True)
                for patients in flattened
            ]
            rows.append(per_split + [sum(per_split)])
        return pd.DataFrame(
            rows, index=labels, columns=["train", "valid", "test", "total"]
        )

    def get_pos_wts(self, labels):
        """Get positive weights to be used in `nn.BCEWithLogitsLoss`.

        Computed as round(negatives / positives) per label and split; a label
        with no positive patients in a split therefore comes out as `inf`.
        """
        positives = self.get_label_counts(labels)
        totals = self.get_lengths().transpose().values  # shape (1, 4): train, valid, test, total
        negatives = totals - positives
        return round(negatives / positives)


In [None]:
show_doc(EHRDataSplits, title_level=3)

<h3 id="EHRDataSplits" class="doc_header"><code>class</code> <code>EHRDataSplits</code><a href="" class="source_link" style="float:right">[source]</a></h3>

> <code>EHRDataSplits</code>(**`path`**, **`age_start`**, **`age_range`**, **`start_is_date`**, **`age_in_months`**)

Class to hold the PatientList splits.

In [None]:
show_doc(EHRDataSplits._load_splits)

<h4 id="EHRDataSplits._load_splits" class="doc_header"><code>EHRDataSplits._load_splits</code><a href="__main__.py#L12" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRDataSplits._load_splits</code>(**`path`**, **`age_start`**, **`age_range`**, **`start_is_date`**, **`age_in_months`**)

Load splits of preprocessed [`PatientList`](/lemonpie/preprocessing_transform.html#PatientList)s from persistent store using path.

In [None]:
show_doc(EHRDataSplits.get_splits_modtypes)

<h4 id="EHRDataSplits.get_splits_modtypes" class="doc_header"><code>EHRDataSplits.get_splits_modtypes</code><a href="__main__.py#L40" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRDataSplits.get_splits_modtypes</code>()

Return splits and modality types.

In [None]:
show_doc(EHRDataSplits.get_lengths)

<h4 id="EHRDataSplits.get_lengths" class="doc_header"><code>EHRDataSplits.get_lengths</code><a href="__main__.py#L44" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRDataSplits.get_lengths</code>()

Return a dataframe with lengths (# of patients) of the splits (train, valid, test) and total.

In [None]:
show_doc(EHRDataSplits.get_label_counts)

<h4 id="EHRDataSplits.get_label_counts" class="doc_header"><code>EHRDataSplits.get_label_counts</code><a href="__main__.py#L56" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRDataSplits.get_label_counts</code>(**`labels`**)

Get prevalence counts of labels in each split 
returns a dataframe with counts for each split and total count.

In [None]:
show_doc(EHRDataSplits.get_pos_wts)

<h4 id="EHRDataSplits.get_pos_wts" class="doc_header"><code>EHRDataSplits.get_pos_wts</code><a href="__main__.py#L84" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRDataSplits.get_pos_wts</code>(**`labels`**)

Get positive weights to be used in `nn.BCEWithLogitsLoss`.

#### Tests on Synthea

In [None]:
PATH_1K, CONDITIONS

('/home/vinod/.lemonpie/datasets/synthea/1K',
 {'diabetes': '44054006',
  'stroke': '230690007',
  'alzheimers': '26929004',
  'coronary_heart': '53741008',
  'lung_cancer': '254637007',
  'breast_cancer': '254837009',
  'rheumatoid_arthritis': '69896004',
  'epilepsy': '84757009'})

In [None]:
labels = list(CONDITIONS.keys())

In [None]:
labels

['diabetes',
 'stroke',
 'alzheimers',
 'coronary_heart',
 'lung_cancer',
 'breast_cancer',
 'rheumatoid_arthritis',
 'epilepsy']

In [None]:
# Load the preprocessed Synthea 1K splits; the cells below depend on `splits`.
splits = EHRDataSplits(PATH_1K, age_start='2000-01-01', age_range=17, start_is_date=True, age_in_months=False)

In [None]:
lengths = splits.get_lengths()
lengths

Unnamed: 0,lengths
train,702
valid,234
test,235
total,1171


In [None]:
prevalence = splits.get_label_counts(labels)
prevalence

Unnamed: 0,train,valid,test,total
diabetes,43,14,19,76
stroke,30,7,11,48
alzheimers,12,7,6,25
coronary_heart,39,11,11,61
lung_cancer,12,0,2,14
breast_cancer,11,8,2,21
rheumatoid_arthritis,2,0,0,2
epilepsy,15,5,2,22


In [None]:
splits.get_pos_wts(labels)

Unnamed: 0,train,valid,test,total
diabetes,15.0,16.0,11.0,14.0
stroke,22.0,32.0,20.0,23.0
alzheimers,58.0,32.0,38.0,46.0
coronary_heart,17.0,20.0,20.0,18.0
lung_cancer,58.0,inf,116.0,83.0
breast_cancer,63.0,28.0,116.0,55.0
rheumatoid_arthritis,350.0,inf,inf,584.0
epilepsy,46.0,46.0,116.0,52.0


In [None]:
lengths.transpose().values

array([[ 702,  234,  235, 1171]])

In [None]:
neg_counts = lengths.transpose().values - prevalence
neg_counts

Unnamed: 0,train,valid,test,total
diabetes,659,220,216,1095
stroke,672,227,224,1123
alzheimers,690,227,229,1146
coronary_heart,663,223,224,1110
lung_cancer,690,234,233,1157
breast_cancer,691,226,233,1150
rheumatoid_arthritis,700,234,235,1169
epilepsy,687,229,233,1149


In [None]:
round(neg_counts / prevalence)

Unnamed: 0,train,valid,test,total
diabetes,15.0,16.0,11.0,14.0
stroke,22.0,32.0,20.0,23.0
alzheimers,58.0,32.0,38.0,46.0
coronary_heart,17.0,20.0,20.0,18.0
lung_cancer,58.0,inf,116.0,83.0
breast_cancer,63.0,28.0,116.0,55.0
rheumatoid_arthritis,350.0,inf,inf,584.0
epilepsy,46.0,46.0,116.0,52.0


**Cross check with raw**
- Check total counts against raw_csv
- Check split counts against split/raw_csv

In [None]:
raw_cnds = pd.read_csv(f'{PATH_1K}/raw_original/conditions.csv', low_memory=False)

In [None]:
cnd_codes = list(CONDITIONS.values())
cnd_codes

['44054006',
 '230690007',
 '26929004',
 '53741008',
 '254637007',
 '254837009',
 '69896004',
 '84757009']

In [None]:
int(CONDITIONS['diabetes'])

44054006

In [None]:
for label in labels:
    print(label,':: ', raw_cnds[raw_cnds.CODE == int(CONDITIONS[label])].CODE.count())

diabetes ::  76
stroke ::  48
alzheimers ::  25
coronary_heart ::  61
lung_cancer ::  14
breast_cancer ::  21
rheumatoid_arthritis ::  2
epilepsy ::  22


In [None]:
raw_cnds_train = pd.read_csv(f'{PATH_1K}/raw_split/train/conditions.csv', low_memory=False)
raw_cnds_valid = pd.read_csv(f'{PATH_1K}/raw_split/valid/conditions.csv', low_memory=False)
raw_cnds_test  = pd.read_csv(f'{PATH_1K}/raw_split/test/conditions.csv', low_memory=False)

In [None]:
for label in labels:
    assert prevalence.loc[label].total == raw_cnds[raw_cnds.CODE == int(CONDITIONS[label])].CODE.count()
    assert prevalence.loc[label].train == raw_cnds_train[raw_cnds_train.CODE == int(CONDITIONS[label])].CODE.count()
    assert prevalence.loc[label].valid == raw_cnds_valid[raw_cnds_valid.CODE == int(CONDITIONS[label])].CODE.count()
    assert prevalence.loc[label].test  == raw_cnds_test [raw_cnds_test.CODE == int(CONDITIONS[label])]. CODE.count()

#### Tests on Coherent

In [None]:
COHERENT_DATA_STORE = '/home/vinod/code/datasets/coherent'
COHERENT_DATAGEN_DATE = '08-10-2021'
COHERENT_CONDITIONS = {
    "heart_failure" : "88805009",
    "coronary_heart" : "53741008",
    "myocardial_infarction" : "22298006",
    "stroke" : "230690007",
    "cardiac_arrest" : "410429000"
}

In [None]:
labels = list(COHERENT_CONDITIONS.keys())

In [None]:
labels

['heart_failure',
 'coronary_heart',
 'myocardial_infarction',
 'stroke',
 'cardiac_arrest']

In [None]:
coherent_splits = EHRDataSplits(COHERENT_DATA_STORE,  age_start=240, age_range=120, start_is_date=False, age_in_months=True)

In [None]:
lengths = coherent_splits.get_lengths()
lengths

Unnamed: 0,lengths
train,1022
valid,128
test,128
total,1278


In [None]:
prevalence = coherent_splits.get_label_counts(labels)
prevalence

Unnamed: 0,train,valid,test,total
heart_failure,257,43,32,332
coronary_heart,260,38,38,336
myocardial_infarction,110,15,20,145
stroke,573,61,64,698
cardiac_arrest,137,20,23,180


In [None]:
coherent_splits.get_pos_wts(labels)

Unnamed: 0,train,valid,test,total
heart_failure,3.0,2.0,3.0,3.0
coronary_heart,3.0,2.0,2.0,3.0
myocardial_infarction,8.0,8.0,5.0,8.0
stroke,1.0,1.0,1.0,1.0
cardiac_arrest,6.0,5.0,5.0,6.0


In [None]:
lengths.transpose().values

array([[1022,  128,  128, 1278]])

In [None]:
neg_counts = lengths.transpose().values - prevalence
neg_counts

Unnamed: 0,train,valid,test,total
heart_failure,765,85,96,946
coronary_heart,762,90,90,942
myocardial_infarction,912,113,108,1133
stroke,449,67,64,580
cardiac_arrest,885,108,105,1098


In [None]:
round(neg_counts / prevalence)

Unnamed: 0,train,valid,test,total
heart_failure,3.0,2.0,3.0,3.0
coronary_heart,3.0,2.0,2.0,3.0
myocardial_infarction,8.0,8.0,5.0,8.0
stroke,1.0,1.0,1.0,1.0
cardiac_arrest,6.0,5.0,5.0,6.0


**Cross check with raw**
- Check total counts against raw_csv
- Check split counts against split/raw_csv

In [None]:
raw_cnds = pd.read_csv(f'{COHERENT_DATA_STORE}/raw_original/conditions.csv', low_memory=False)

In [None]:
cnd_codes = list(COHERENT_CONDITIONS.values())
cnd_codes

['88805009', '53741008', '22298006', '230690007', '410429000']

In [None]:
for label in labels:
    print(label,':: ', raw_cnds[raw_cnds.CODE == int(COHERENT_CONDITIONS[label])].CODE.count())

heart_failure ::  332
coronary_heart ::  336
myocardial_infarction ::  145
stroke ::  698
cardiac_arrest ::  180


In [None]:
raw_cnds_train = pd.read_csv(f'{COHERENT_DATA_STORE}/raw_split/train/conditions.csv', low_memory=False)
raw_cnds_valid = pd.read_csv(f'{COHERENT_DATA_STORE}/raw_split/valid/conditions.csv', low_memory=False)
raw_cnds_test  = pd.read_csv(f'{COHERENT_DATA_STORE}/raw_split/test/conditions.csv', low_memory=False)

In [None]:
for label in labels:
    assert prevalence.loc[label].total == raw_cnds[raw_cnds.CODE == int(COHERENT_CONDITIONS[label])].CODE.count()
    assert prevalence.loc[label].train == raw_cnds_train[raw_cnds_train.CODE == int(COHERENT_CONDITIONS[label])].CODE.count()
    assert prevalence.loc[label].valid == raw_cnds_valid[raw_cnds_valid.CODE == int(COHERENT_CONDITIONS[label])].CODE.count()
    assert prevalence.loc[label].test  == raw_cnds_test [raw_cnds_test.CODE == int(COHERENT_CONDITIONS[label])]. CODE.count()

## Label

**Labeling** definition in fastai -- some processes need to be run on `train` and **applied** to `valid`

This is completed in preprocessing (vocab & transform) as follows
1. Vocabs created from train data
    - Tokenizing unique values for different record codes & demographic values
    - Calculating mean and std for age
2. Vocabs applied to train, valid and test data
    - With `numericalize` for record codes & demographic values
    - With normalizing of age with the mean / std from train

**Hence labeling in our case will be creating X and y**

- X is the patient object
- y (for a single patient) needs to be a tensor made out of the patient's values for labels ('diabetes', 'stroke', 'alzheimers', 'coronary_heart', 'lung_cancer') 

So **creating the `y` tensor** is simply a matter of ..
1. extracting the values of each of the labels from each `Patient` object 
2. turning it into a `torch.FloatTensor`
3. and stacking them up using `torch.stack`

In [None]:
tst_y = np.array((True, False, False, True), dtype='float')
torch.from_numpy(tst_y), torch.FloatTensor(tst_y)

(tensor([1., 0., 0., 1.], dtype=torch.float64), tensor([1., 0., 0., 1.]))

There are two ways of creating a torch tensor from a numpy array; we will stick with the latter (`torch.FloatTensor`), which gives a `float32` tensor instead of `float64`.

In [None]:
y = []
for pt in splits.valid[0]:
    y.append(torch.FloatTensor(np.array([pt.conditions[label] for label in labels], dtype='float')) )

In [None]:
# y

In [None]:
y = torch.stack(y)

In [None]:
y.shape

torch.Size([3, 5])

In [None]:
y

tensor([[1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.]])

Putting it into a function

In [None]:
def label_data(patient_ds, labels) -> 'x,y':
    '''Pair a patient dataset with its label tensor.

    Returns `(x, y)`: `x` is `patient_ds` itself (unchanged), `y` is a float tensor
    with one row per patient holding `pt.conditions[label]` for each label in `labels`.
    '''
    label_rows = [
        torch.FloatTensor(np.array([pt.conditions[name] for name in labels], dtype='float'))
        for pt in patient_ds
    ]
    return patient_ds, torch.stack(label_rows)

In [None]:
x_train,y_train = label_data(splits.train, labels)
x_valid,y_valid = label_data(splits.valid, labels)
x_test ,y_test  = label_data(splits.test , labels)

In [None]:
y_train.shape, y_valid.shape, y_test.shape

(torch.Size([702, 8]), torch.Size([234, 8]), torch.Size([235, 8]))

In [None]:
torch.full((10,1), 2)

tensor([[2],
        [2],
        [2],
        [2],
        [2],
        [2],
        [2],
        [2],
        [2],
        [2]])

In [None]:
#export
class LabelEHRData():
    '''Holds the train / valid / test splits together with their label tensors.'''
    def __init__(self, train, valid, test, labels):
        '''Build y for every split; each split is then exposed as a tuple (x, y)
        where x is the given patient collection and y the stacked tensor of conditions.'''
        self.x_train, self.x_valid, self.x_test = train, valid, test

        self.y_train = self._get_y(train, labels)
        self.y_valid = self._get_y(valid, labels)
        self.y_test = self._get_y(test, labels)

        self.train = (self.x_train, self.y_train)
        self.valid = (self.x_valid, self.y_valid)
        self.test = (self.x_test, self.y_test)

    def _get_y(self, ds, labels):
        '''Stack one float row per patient in `ds`, holding that patient's value for each label.'''
        rows = [
            torch.FloatTensor(np.array([pt.conditions[name] for name in labels], dtype='float'))
            for pt in ds
        ]
        return torch.stack(rows)

In [None]:
show_doc(LabelEHRData, title_level=3)

<h3 id="LabelEHRData" class="doc_header"><code>class</code> <code>LabelEHRData</code><a href="" class="source_link" style="float:right">[source]</a></h3>

> <code>LabelEHRData</code>(**`train`**, **`valid`**, **`test`**, **`labels`**)

Class to hold labeled EHR data splits

In [None]:
show_doc(LabelEHRData.__init__)

<h4 id="LabelEHRData.__init__" class="doc_header"><code>LabelEHRData.__init__</code><a href="__main__.py#L4" class="source_link" style="float:right">[source]</a></h4>

> <code>LabelEHRData.__init__</code>(**`train`**, **`valid`**, **`test`**, **`labels`**)

Extracts y from patient object, each labelset a tuple of x,y: x=Patient object, y=tensor of conditions

In [None]:
show_doc(LabelEHRData._get_y)

<h4 id="LabelEHRData._get_y" class="doc_header"><code>LabelEHRData._get_y</code><a href="__main__.py#L14" class="source_link" style="float:right">[source]</a></h4>

> <code>LabelEHRData._get_y</code>(**`ds`**, **`labels`**)

Extract y from each patient object in ds and stack them - ds is dataset containing patient objects

In [None]:
labeled = LabelEHRData(*splits.get_splits(), labels)

In [None]:
labeled.train

(PatientList (702 items)
 base path:/home/vinod/.lemonpie/datasets/synthea/1K; split:train
 age_start:2000-01-01; age_range:17; age_type:years
 ptid:0ace3e15-8aa4-41c5-8b90-2408285ebcfe, birthdate:1986-04-02 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu
 ptid:af1495be-5077-4087-98b1-9ff624c7582c, birthdate:2008-07-17 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu
 ptid:f23e12d9-2ec6-4006-b041-ea78d374e9c9, birthdate:2014-09-06 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu
 ptid:1968aa31-5fce-461a-9486-6e385a7b75e7, birthdate:1986-04-11 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu
 ptid:1211c8ff-ab73-49f3-b2ab-87b7a03f6167, birthdate:1972-03-24 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu
 ptid:27a8b7b6-007d-4036-82a7-80a9ab670dcb, birthdate:2005-04-13 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu
 ptid:532696f2-0b76-4eb0-9aea-a74e2fb1bed2, birthdate:1967-05-18 00:00:00, [(

## Dataset

[Subclasses](https://pytorch.org/docs/master/data.html?highlight=dataloader#torch.utils.data.Dataset) `torch.utils.data.Dataset`<br>
- that is, it implements `__len__()` and `__getitem__()`

In [None]:
# export
class EHRDataset(torch.utils.data.Dataset):
    """Class to hold a single EHR dataset (holds a tuple of x, y & m for modality type).
    Also handles lazy vs full loading of dataset on GPU."""

    def __init__(
        self,
        ptlist: list,
        labels: list,
        modality_type: int,
        lazy_load_gpu: bool = True,
    ):
        """Extract y, create x,y: x=Patient object, y=tensor of conditions
        If `lazy_load_gpu` is `False`, load entire dataset on GPU.

        `modality_type` is stored as-is in `self.m` and returned with every item.
        """

        self.x, self.y = ptlist, self._get_y(ptlist, labels)
        # self.m = torch.full((len(ptlist), 1), modality_type)
        self.m = modality_type
        self.lazy = lazy_load_gpu

        # Truthiness test, consistent with the check in `__getitem__`.
        if not self.lazy:
            # NOTE(review): `to_gpu()` / `DEVICE` come from outside this block -
            # presumably they move the patient / tensor to the configured device.
            self.x = [pt.to_gpu() for pt in self.x]
            self.y = self.y.to(DEVICE)
            # The modality type is normally a plain int, which has no `.to()`;
            # only move it when a tensor was passed in.
            if torch.is_tensor(self.m):
                self.m = self.m.to(DEVICE)

    def _get_y(self, ptlist, labels):
        """Extract y from each patient object in ptlist and stack them."""
        y = [
            torch.FloatTensor(
                np.array([pt.conditions[label] for label in labels], dtype="float")
            )
            for pt in ptlist
        ]
        return torch.stack(y)

    def __len__(self):
        """Number of patients in this dataset."""
        return len(self.x)

    def _test_getitem(self, i):
        """Return item `i` without copying, regardless of the lazy setting."""
        return self.x[i], self.y[i], self.m

    def __getitem__(self, i):
        """If lazy loading, return deep copy of patient object `i`
        else entire dataset already on GPU - just return `i`."""
        if self.lazy:
            return copy.deepcopy(self.x[i]), self.y[i], self.m  # make m[i] if tensor
        else:
            return self.x[i], self.y[i], self.m


In [None]:
show_doc(EHRDataset, title_level=3)

<h3 id="EHRDataset" class="doc_header"><code>class</code> <code>EHRDataset</code><a href="" class="source_link" style="float:right">[source]</a></h3>

> <code>EHRDataset</code>(**\*`args`**, **\*\*`kwds`**) :: `Dataset`

Class to hold a single EHR dataset (holds a tuple of x, y & m for modality type).
Also handles lazy vs full loading of dataset on GPU.

In [None]:
show_doc(EHRDataset.__init__)

<h4 id="EHRDataset.__init__" class="doc_header"><code>EHRDataset.__init__</code><a href="__main__.py#L6" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRDataset.__init__</code>(**`ptlist`**:`list`, **`labels`**:`list`, **`modality_type`**:`int`, **`lazy_load_gpu`**:`bool`=*`True`*)

Extract y, create x,y: x=Patient object, y=tensor of conditions
If `lazy_load_gpu` is `False`, load entire dataset on GPU.

In [None]:
show_doc(EHRDataset._get_y)

<h4 id="EHRDataset._get_y" class="doc_header"><code>EHRDataset._get_y</code><a href="__main__.py#L26" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRDataset._get_y</code>(**`ptlist`**, **`labels`**)

Extract y from each patient object in ptlist and stack them.

In [None]:
show_doc(EHRDataset.__getitem__)

<h4 id="EHRDataset.__getitem__" class="doc_header"><code>EHRDataset.__getitem__</code><a href="__main__.py#L43" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRDataset.__getitem__</code>(**`i`**)

If lazy loading, return deep copy of patient object `i`
else entire dataset already on GPU - just return `i`.

Since `Patient` is a custom object and not a typical tensor, we need to handle the behavior for `Dataset`, `DataLoader`, etc to function correctly.
- Memory pinning is a good idea for better performance if lazy loading to GPU
    - [A discussion - pin memory vs full load to GPU](https://discuss.pytorch.org/t/pin-memory-vs-sending-direct-to-gpu-from-dataset/33891)
- When a DataLoader pins memory for a tensor, a copy of the tensor is made in page-locked memory in RAM (as opposed to swappable memory), which speeds up transfers to the GPU
    - [A good explanation](https://stackoverflow.com/questions/5736968/why-is-cuda-pinned-memory-so-fast)
- But on custom data type like our `Patient` object, we need to define the behavior
    - [Pytorch docs](https://pytorch.org/docs/stable/data.html#memory-pinning)
- Making a [deep copy](https://docs.python.org/3/library/copy.html) of the `Patient` object to mimic tensor behavior
    - Otherwise, given that the `Patient` holds its changed tensors, all tensors are CUDA tensors after the first epoch; the DataLoader then tries to pin memory again, which causes an error (TODO: Need to elaborate)

### Handling Multimodal Data

#### Split-level Dataset & Custom Batch Sampler - one each for train, valid & test

For each split (train, valid and test), we use  
- a `ConcatDataset` to hold multiple `MultimodalDataset`s.
- a custom batch sampler `ModalityTypeBatchSampler` that creates batches with the same modality type.

Modality type is the combination of data modalities available for a given patient.
| Modality Type | Modalities            |   
|---	        |---	                |
| **0**	        | **EHR**               |
| **1**         | EHR + **MRI**         |
| **10**        | EHR + **DNA**         |      
| 11   	        | EHR + MRI + DNA       |
| **20**        | EHR + **ECG**         |
| 21            | EHR + MRI + ECG       |
| 30            | EHR + DNA + ECG       |
| 31            | EHR + MRI + DNA + ECG |




Solution used here is from this Pytorch forum discussion 
- https://discuss.pytorch.org/t/how-to-concatenate-different-datasets-each-with-different-dimensions/123218
- The `ConcatDataset` holds all the specific `MultimodalDataset`s together - one for each modality combination.
    - For example - (EHR + MRI + ECG + Notes), (EHR + MRI), (EHR + DNA + MRI + Notes), etc. 
- The custom batch sampler ensures that each batch only has elements from one of the `MultimodalDataset`s.
    - But also provides shuffling across the various types.

In [None]:
# export


class ModalityTypeBatchSampler(Sampler):
    """Custom BatchSampler for multimodal data.

    Every batch is drawn from exactly one of the index lists in `indices_list`
    (one list per modality-type dataset), so a batch never mixes modality types.
    Batches are built once, at construction time; when `shuffle` is set, only the
    order of the batches is re-shuffled on each iteration.
    """

    def __init__(self, indices_list: list, batch_size: int, shuffle: bool):
        """Init with indices from every modality-type dataset and create all batches."""
        self.indices_list = indices_list
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.all_batches = self._create_batches()

    def _chunk(self, indices, size):
        """Chunk indices into batch size (last chunk may be smaller)."""
        return torch.split(torch.tensor(indices), size)

    def _create_batches(self):
        """Create batches: a list of lists of indices, each from a single modality type."""
        all_batches = []
        for indices in self.indices_list:
            # Skip empty modality types: `torch.split` on an empty tensor would
            # still return one (empty) chunk, i.e. an empty batch.
            if len(indices) == 0:
                continue
            # Work on a copy so the caller's index lists are never reordered.
            indices = list(indices)
            if self.shuffle:
                random.shuffle(indices)
            all_batches.extend(
                chunk.tolist() for chunk in self._chunk(indices, self.batch_size)
            )

        return all_batches

    def __iter__(self):
        """Iterable used by dataloaders; reshuffles the batch order when `shuffle` is set."""
        if self.shuffle:
            random.shuffle(self.all_batches)
        return iter(self.all_batches)

    def __len__(self):
        """Return the number of batches across all concatenated datasets."""
        return len(self.all_batches)


In [None]:
show_doc(ModalityTypeBatchSampler, title_level=3)

<h3 id="ModalityTypeBatchSampler" class="doc_header"><code>class</code> <code>ModalityTypeBatchSampler</code><a href="" class="source_link" style="float:right">[source]</a></h3>

> <code>ModalityTypeBatchSampler</code>(**\*`args`**, **\*\*`kwds`**) :: `Sampler`

Custom BatchSampler for multimodal data.

In [None]:
show_doc(ModalityTypeBatchSampler.__init__)

<h4 id="ModalityTypeBatchSampler.__init__" class="doc_header"><code>ModalityTypeBatchSampler.__init__</code><a href="__main__.py#L7" class="source_link" style="float:right">[source]</a></h4>

> <code>ModalityTypeBatchSampler.__init__</code>(**`indices_list`**:`list`, **`batch_size`**:`int`, **`shuffle`**:`bool`)

Init with indicies from every modality-type dataset and create all batches.

In [None]:
show_doc(ModalityTypeBatchSampler._chunk)

<h4 id="ModalityTypeBatchSampler._chunk" class="doc_header"><code>ModalityTypeBatchSampler._chunk</code><a href="__main__.py#L14" class="source_link" style="float:right">[source]</a></h4>

> <code>ModalityTypeBatchSampler._chunk</code>(**`indices`**, **`size`**)

Chunk indices into batch size.

In [None]:
show_doc(ModalityTypeBatchSampler._create_batches)

<h4 id="ModalityTypeBatchSampler._create_batches" class="doc_header"><code>ModalityTypeBatchSampler._create_batches</code><a href="__main__.py#L18" class="source_link" style="float:right">[source]</a></h4>

> <code>ModalityTypeBatchSampler._create_batches</code>()

Create batches.

In [None]:
show_doc(ModalityTypeBatchSampler.__iter__)

<h4 id="ModalityTypeBatchSampler.__iter__" class="doc_header"><code>ModalityTypeBatchSampler.__iter__</code><a href="__main__.py#L29" class="source_link" style="float:right">[source]</a></h4>

> <code>ModalityTypeBatchSampler.__iter__</code>()

Iterable used by dataloaders.

In [None]:
show_doc(ModalityTypeBatchSampler.__len__)

<h4 id="ModalityTypeBatchSampler.__len__" class="doc_header"><code>ModalityTypeBatchSampler.__len__</code><a href="__main__.py#L35" class="source_link" style="float:right">[source]</a></h4>

> <code>ModalityTypeBatchSampler.__len__</code>()

Return length based on concated datasets.

In [None]:
# export
def create_modality_ds_sampler(
    ehr_dataset_list: list, batch_size: int, shuffle: bool
):
    """Create a custom ConcatDataset and BatchSampler for modality types.

    Returns `(dataset, batch_sampler)`: the concatenation of all datasets in
    `ehr_dataset_list`, and a `ModalityTypeBatchSampler` built from one list of
    global indices per dataset.
    """
    modtype_dataset = torch.utils.data.ConcatDataset(ehr_dataset_list)

    # `cumulative_sizes[k]` is the end (exclusive) of dataset k in the concatenated
    # index space; its start is the previous end, or 0 for the first dataset.
    ends = modtype_dataset.cumulative_sizes
    starts = [0] + ends[:-1]
    per_dataset_indices = [list(range(lo, hi)) for lo, hi in zip(starts, ends)]

    batch_sampler = ModalityTypeBatchSampler(per_dataset_indices, batch_size, shuffle)
    return modtype_dataset, batch_sampler


In [None]:
show_doc(create_modality_ds_sampler)

<h4 id="create_modality_ds_sampler" class="doc_header"><code>create_modality_ds_sampler</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>create_modality_ds_sampler</code>(**`ehr_dataset_list`**:`list`, **`batch_size`**:`int`, **`shuffle`**:`bool`)

Create a custom ConcatDataset and BatchSampler for modality types.

#### Multimodal Datasets & Custom Collate Functions


- https://discuss.pytorch.org/t/train-simultaneously-on-two-datasets/649
- This approach  can be used to simultaneously read from 2 Datasets and get a tuple.

In [None]:
# export
class MultimodalDataset(torch.utils.data.Dataset):
    """Multimodal dataset for EHR plus other modalities.

    `ds_list[0]` is the EHR dataset (indexed by position); any further datasets
    are looked up by the patient id taken from the EHR item.
    """

    def __init__(self, ds_list):
        """Separate EHR and other modalities."""
        self.ehr_ds = ds_list[0]
        # `other_ds_list` is only defined when extra modalities were given.
        if len(ds_list) > 1:
            self.other_ds_list = ds_list[1:]

    def __len__(self):
        """Return count of patients in this modality type."""
        return len(self.ehr_ds)

    def __getitem__(self, i):
        """Get patient_ids from EHRDataset and
        use them to fetch data of other modalities.

        Returns `(ehr_item, other)` where `other` is `None` (EHR only), a single
        item (one extra modality) or a tuple of items (several extra modalities).
        """
        ehr_item = self.ehr_ds[i]
        ptid = ehr_item[0].ptid  # first element of the EHR item is the patient

        others = getattr(self, "other_ds_list", None)
        if others is None:
            return ehr_item, None
        if len(others) == 1:
            return ehr_item, others[0][ptid]
        return ehr_item, tuple(ds[ptid] for ds in others)


In [None]:
# export
def collate_ehr(batch):
    """Custom collate fn for EHR plus 3 other modalities.

    `batch` is a sequence of `((patient, y, modality_type), other)` items.
    Returns `(pts, ys, ms, other)`: tuple of patients, stacked `y` tensor, tuple of
    modality types, and the other-modality data stacked according to the modality
    type of the first item (left untouched for any type not listed below).
    """
    ehr_items, other = zip(*batch)
    pts, ys, ms = zip(*ehr_items)
    ys = torch.stack(ys)

    # All items in a batch are expected to share one modality type.
    mod_type = ms[0]
    if mod_type in (1, 10, 20):
        # one extra modality
        other = torch.stack(other)
    elif mod_type in (11, 21, 30):
        # two extra modalities
        first, second = zip(*other)
        other = (torch.stack(first), torch.stack(second))
    elif mod_type == 31:
        # three extra modalities
        first, second, third = zip(*other)
        other = (torch.stack(first), torch.stack(second), torch.stack(third))

    return pts, ys, ms, other


##### Quick Tests

In [None]:
splits, mod_types = coherent_splits.get_splits_modtypes()

In [None]:
mod_types

{'train': ['0', '21', '30', '31', '11', '1', '20', '10'],
 'valid': ['21', '30', '31', '11', '1', '20', '10'],
 'test': ['21', '30', '31', '11', '1', '20', '10']}

In [None]:
train_ptlists = splits["train"]
train_mod_types = mod_types["train"]
assert len(train_ptlists) == len(train_mod_types)

In [None]:
class UnimodalDataset(torch.utils.data.Dataset):
    """Stand-in modality dataset that returns a constant-filled tensor for any key."""

    def __init__(self, type: int, tensor_sz):
        super().__init__()
        # `type` is the fill value; `tensor_sz` is the shape of every returned tensor.
        self.type, self.tensor_sz = type, tensor_sz

    def __len__(self):
        """Fixed nominal length; items are synthesized, not stored."""
        return 100

    def __getitem__(self, i):
        """Return a tensor of shape `tensor_sz` filled with `type`; the key `i` is ignored."""
        return torch.full(size=self.tensor_sz, fill_value=self.type)


In [None]:
mri_ds = UnimodalDataset(1, (3,3))
ecg_ds = UnimodalDataset(2, (5,))
dna_ds = UnimodalDataset(3, (4,2))
# notes_ds = UnimodalDataset("notes")

In [None]:
type0_ehr_ds = EHRDataset(train_ptlists[1], labels, 0)
type1_ehr_ds = EHRDataset(train_ptlists[1], labels, 1)
type2_ehr_ds = EHRDataset(train_ptlists[1], labels, 21)
type3_ehr_ds = EHRDataset(train_ptlists[1], labels, 31)

In [None]:
type0_mm_ds = MultimodalDataset([type0_ehr_ds])
type1_mm_ds = MultimodalDataset([type1_ehr_ds, mri_ds])
type2_mm_ds = MultimodalDataset([type2_ehr_ds, ecg_ds, mri_ds])
type3_mm_ds = MultimodalDataset([type3_ehr_ds, dna_ds, ecg_ds, mri_ds])

In [None]:
train_ds, train_sampler = create_modality_ds_sampler([type0_mm_ds, type1_mm_ds, type2_mm_ds, type3_mm_ds], batch_size=4, shuffle=True)
train_dl = DataLoader(train_ds,  batch_sampler=train_sampler, num_workers=0, collate_fn=collate_ehr)

In [None]:
pts, ys, ms, other = next(iter(train_dl))
ms

(0, 0, 0, 0)

In [None]:
other

(None, None, None, None)

In [None]:
pts

(ptid:fbb75ebb-8b10-2a28-5634-3b8f4da7442b, birthdate:1939-08-12, [('heart_failure', True), ('coronary_heart', True)].., device:cpu,
 ptid:e4aa3da5-5600-6038-9b92-a93278fbe3ed, birthdate:1935-12-22, [('heart_failure', True), ('coronary_heart', False)].., device:cpu,
 ptid:c12614ed-2db3-cb59-33c6-be45d445192e, birthdate:1930-12-12, [('heart_failure', True), ('coronary_heart', False)].., device:cpu,
 ptid:53931443-cc8d-9ac0-9ce8-8f3cc30712ab, birthdate:1932-02-20, [('heart_failure', True), ('coronary_heart', True)].., device:cpu)

#### Unimodal Datasets - One Per Modality

In [None]:
# export

class MRIDataset(torch.utils.data.Dataset):
    """Placeholder MRI dataset keyed by patient id.

    An item exists for a patient when exactly one file in `mri_dir` has the
    patient id in its name; the returned tensor is a constant stand-in and is
    not read from the file.
    """

    def __init__(self, mri_dir: str, tensor_sz: tuple):
        """
        Args:
            mri_dir: directory that holds the MRI files.
            tensor_sz: shape of the tensor returned for every patient.
        """
        super().__init__()
        self.mri_dir = mri_dir
        self.tensor_sz = tensor_sz

    def __getitem__(self, i):
        """Return a `tensor_sz` tensor filled with 1 for patient id `i`.

        Raises:
            Exception: if zero or more than one file name contains `i`.
        """
        # Escape both parts so glob metacharacters ([, ], ?, *) in the
        # directory path or the id are matched literally.
        pattern = f"{glob.escape(str(self.mri_dir))}/*{glob.escape(str(i))}*"
        mri_fname = glob.glob(pattern)
        if len(mri_fname) != 1:
            raise Exception(f"MRI filename match error - found {len(mri_fname)} files with ptid: {i}.")
        return torch.full(self.tensor_sz, 1)

    def __len__(self):
        # Constant placeholder: the dataset is looked up by patient id, not iterated.
        return 1


In [None]:
# export

class DNADataset(torch.utils.data.Dataset):
    """Placeholder DNA dataset keyed by patient id.

    An item exists for a patient when exactly one file in `dna_dir` has the
    patient id in its name; the returned tensor is a constant stand-in and is
    not read from the file.
    """

    def __init__(self, dna_dir: str, tensor_sz: tuple):
        """
        Args:
            dna_dir: directory that holds the DNA files.
            tensor_sz: shape of the tensor returned for every patient.
        """
        super().__init__()
        self.dna_dir = dna_dir
        self.tensor_sz = tensor_sz

    def __getitem__(self, i):
        """Return a `tensor_sz` tensor filled with 10 for patient id `i`.

        Raises:
            Exception: if zero or more than one file name contains `i`.
        """
        # Escape both parts so glob metacharacters ([, ], ?, *) in the
        # directory path or the id are matched literally.
        pattern = f"{glob.escape(str(self.dna_dir))}/*{glob.escape(str(i))}*"
        dna_fname = glob.glob(pattern)
        if len(dna_fname) != 1:
            raise Exception(f"DNA filename match error - found {len(dna_fname)} files with ptid: {i}.")
        return torch.full(self.tensor_sz, 10)

    def __len__(self):
        # Constant placeholder: the dataset is looked up by patient id, not iterated.
        return 1


In [None]:
# export

class ECGDataset(torch.utils.data.Dataset):
    """Placeholder ECG dataset keyed by patient id.

    Patient ids are taken from the `patient` column of `<ecg_dir>/ecg.csv`;
    the returned tensor is a constant stand-in, not the ECG signal.
    """

    def __init__(self, ecg_dir: str, tensor_sz: tuple):
        """
        Args:
            ecg_dir: directory containing `ecg.csv`.
            tensor_sz: shape of the tensor returned for every patient.
        """
        super().__init__()
        ecg_data = pd.read_csv(f"{ecg_dir}/ecg.csv")
        # Distinct patient ids present in the ECG file.
        self.ecg_pids = ecg_data.patient.unique()
        self.tensor_sz = tensor_sz

    def __getitem__(self, i):
        """Return a `tensor_sz` tensor filled with 20 for patient id `i`.

        Raises:
            Exception: if `i` is not among the ids read from `ecg.csv`.
        """
        known = i in self.ecg_pids
        if not known:
            raise Exception(f"ptid: {i} - not found in ECG data.")
        return torch.full(self.tensor_sz, 20)

    def __len__(self):
        # Constant placeholder: the dataset is looked up by patient id, not iterated.
        return 1


In [None]:
mri_ds = MRIDataset(f"{COHERENT_DATA_STORE}/output/dicom", (4,4))
dna_ds = DNADataset(f"{COHERENT_DATA_STORE}/output/dna", (3,2))
ecg_ds = ECGDataset(f"{COHERENT_DATA_STORE}", (5,))

In [None]:
# mri_ds["e87a6fbb-f0c0-fac0-4207-93698009e721"] # exception
mri_ds["5293a3a9-777a-ffdf-d9ec-3439c0115d231"]

tensor([[1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1]])

In [None]:
# dna_ds["5293a3a9-777a-ffdf-d9ec-3439c0115d231"] # exception
dna_ds["e87a6fbb-f0c0-fac0-4207-93698009e721"]

tensor([[10, 10],
        [10, 10],
        [10, 10]])

In [None]:
# ecg_ds["9c452d24-00b0-d58f-4cd5-b82bd6695647"] # exception
ecg_ds["9c452d24-00b0-d58f-4cd5-b82bd6695646"]


tensor([20, 20, 20, 20, 20])

#### Tests on Synthea

In [None]:
def get_ds(x_train, y_train, x_valid, y_valid, modality_type) -> 'train_ds, valid_ds':
    """Build the train and valid `EHRDataset`s, both tagged with `modality_type`."""
    train_ds = EHRDataset(x_train, y_train, modality_type)
    valid_ds = EHRDataset(x_valid, y_valid, modality_type)
    return train_ds, valid_ds

**Testing Lazy Load**

In [None]:
train_ds, valid_ds = get_ds(*labeled.train, *labeled.valid, 3)

In [None]:
len(train_ds), len(valid_ds)

(702, 234)

In [None]:
len(labeled.train), len(labeled.x_train)

(2, 702)

In [None]:
assert len(train_ds)==len(labeled.x_train)==len(labeled.y_train)
assert len(valid_ds)==len(labeled.y_valid)==len(labeled.x_valid)

In [None]:
xb,yb, mb = train_ds[0:7]
xb,yb, mb

([ptid:0ace3e15-8aa4-41c5-8b90-2408285ebcfe, birthdate:1986-04-02 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
  ptid:af1495be-5077-4087-98b1-9ff624c7582c, birthdate:2008-07-17 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
  ptid:f23e12d9-2ec6-4006-b041-ea78d374e9c9, birthdate:2014-09-06 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
  ptid:1968aa31-5fce-461a-9486-6e385a7b75e7, birthdate:1986-04-11 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
  ptid:1211c8ff-ab73-49f3-b2ab-87b7a03f6167, birthdate:1972-03-24 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
  ptid:27a8b7b6-007d-4036-82a7-80a9ab670dcb, birthdate:2005-04-13 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
  ptid:532696f2-0b76-4eb0-9aea-a74e2fb1bed2, birthdate:1967-05-18 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu],
 tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.

In [None]:
yb.shape, mb.shape

(torch.Size([7, 8]), torch.Size([7, 1]))

In [None]:
xb[0].obs_nums.is_pinned()

False

In [None]:
train_ds._test_getitem(0)

(ptid:0ace3e15-8aa4-41c5-8b90-2408285ebcfe, birthdate:1986-04-02 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
 tensor([0., 0., 0., 0., 0., 0., 0., 0.]),
 tensor([3]))

#### Testing multimodal with toy data

In [None]:
class ToyEHR_DS(torch.utils.data.Dataset):
    """Toy EHR Dataset for testing multimodal functionality.

    Items are `(x, y, modality_type)` triples, where the modality type is a
    single value shared by the whole dataset.
    """

    def __init__(self, x_labeled: list, y_labeled: Tensor, modality_type: int):
        self.x = x_labeled
        self.y = y_labeled
        # One modality code for the whole dataset rather than a per-row tensor.
        self.m = modality_type

    def __getitem__(self, i):
        """Return the `(x, y, modality_type)` triple at index `i`."""
        return self.x[i], self.y[i], self.m

    def _test_getitem(self, i):
        """Explicitly named alias of `__getitem__` for use in tests."""
        return self[i]

    def __len__(self):
        """Number of samples, i.e. the length of `x`."""
        return len(self.x)


In [None]:
type1_ds = ToyEHR_DS(['type1_1', 'type1_2', 'type1_3', 'type1_4'], torch.tensor((1, 0, 1, 0)), 1)
type2_ds = ToyEHR_DS(['type2_1', 'type2_2', 'type2_3', 'type2_4', 'type2_5', 'type2_6', 'type2_7'], torch.tensor((1, 0, 1, 0, 0, 1, 1)), 2)
type3_ds = ToyEHR_DS(['type3_1', 'type3_2', 'type3_3', 'type3_4', 'type3_5'], torch.tensor((1, 0, 1, 0, 0)), 3)

In [None]:
type2_ds[3], len(type2_ds)

(('type2_4', tensor(0), 2), 7)

`shuffle = False`

In [None]:
modtype_ds, sampler = create_modality_ds_sampler([type1_ds, type2_ds, type3_ds], batch_size=2, shuffle=False)

In [None]:
dl = DataLoader(modtype_ds,  batch_sampler=sampler)
len(dl)

9

In [None]:
for i, (x, y, m) in enumerate(dl):
    print(f"i={i} -- x={x}") #, y:{y} \n m:{m}")

i=0 -- x=('type1_1', 'type1_2')
i=1 -- x=('type1_3', 'type1_4')
i=2 -- x=('type2_1', 'type2_2')
i=3 -- x=('type2_3', 'type2_4')
i=4 -- x=('type2_5', 'type2_6')
i=5 -- x=('type2_7',)
i=6 -- x=('type3_1', 'type3_2')
i=7 -- x=('type3_3', 'type3_4')
i=8 -- x=('type3_5',)


`shuffle = True`

In [None]:
mm_ds, sampler = create_modality_ds_sampler([type1_ds, type2_ds, type3_ds], batch_size=2, shuffle=True)

In [None]:
dl = DataLoader(mm_ds,  batch_sampler=sampler)
len(dl)

9

In [None]:
for i, (x, y, m) in enumerate(dl):
    print(f"i={i} -- x={x}") # \n y:{y} \n m:{m}")

i=0 -- x=('type3_5', 'type3_3')
i=1 -- x=('type2_7', 'type2_2')
i=2 -- x=('type1_2', 'type1_3')
i=3 -- x=('type3_2',)
i=4 -- x=('type2_6', 'type2_4')
i=5 -- x=('type2_5', 'type2_1')
i=6 -- x=('type2_3',)
i=7 -- x=('type3_1', 'type3_4')
i=8 -- x=('type1_4', 'type1_1')


**Single modality type** - for example just EHR tabular.

In [None]:
mm_ds, sampler = create_modality_ds_sampler([type2_ds], batch_size=2, shuffle=False)

In [None]:
dl = DataLoader(mm_ds,  batch_sampler=sampler)
len(dl)

4

In [None]:
for i, (x, y, m) in enumerate(dl):
    print(f"i={i} -- x={x}") # \n y:{y} \n m:{m}")

i=0 -- x=('type2_1', 'type2_2')
i=1 -- x=('type2_3', 'type2_4')
i=2 -- x=('type2_5', 'type2_6')
i=3 -- x=('type2_7',)


**Testing Multimodal functionality**

In [None]:
class UnimodalDataset(torch.utils.data.Dataset):
    """Toy modality dataset that answers any key with the string "<type>-<key>"."""

    def __init__(self, type: str):
        super().__init__()
        # Modality label used as the prefix of every returned item.
        self.type = type

    def __len__(self):
        """Fixed nominal length; items are synthesized, not stored."""
        return 100

    def __getitem__(self, i):
        """Return the modality label joined to the key `i` with a dash."""
        return "{}-{}".format(self.type, i)


In [None]:
class ToyMMDataset(torch.utils.data.Dataset):
    """Toy multimodal dataset: an EHR dataset plus datasets keyed by the EHR item's `x`."""

    def __init__(self, ds_list):
        """
        Args:
            ds_list: sequence of datasets; element 0 is the EHR dataset, the
                rest are the other modalities.
        """
        self.ehr_dataset = ds_list[0]
        self.other_datasets = ds_list[1:]

    def __getitem__(self, i):
        """Return `(ehr_item, others)`, where `others` holds one entry per other
        dataset, each looked up with the first element of the EHR item."""
        # Fetch the EHR item once and reuse it for both the key and the result.
        ehr_item = self.ehr_dataset[i]
        pts, _, _ = ehr_item
        return ehr_item, tuple(d[pts] for d in self.other_datasets)

    def __len__(self):
        """Number of items in the EHR dataset."""
        return len(self.ehr_dataset)

In [None]:
type1_ehr_ds = ToyEHR_DS(['type1_1', 'type1_2', 'type1_3', 'type1_4'], torch.tensor((1, 0, 1, 0)), 1)
type2_ehr_ds = ToyEHR_DS(['type2_1', 'type2_2', 'type2_3', 'type2_4', 'type2_5', 'type2_6', 'type2_7'], torch.tensor((1, 0, 1, 0, 0, 1, 1)), 2)
type3_ehr_ds = ToyEHR_DS(['type3_1', 'type3_2', 'type3_3', 'type3_4', 'type3_5'], torch.tensor((1, 0, 1, 0, 0)), 3)

ehr_batch_sz = 2

mri_ds = UnimodalDataset("mri")
ecg_ds = UnimodalDataset("ecg")
dna_ds = UnimodalDataset("dna")
notes_ds = UnimodalDataset("notes")

type1_mm_ds = ToyMMDataset([type1_ehr_ds, mri_ds, ecg_ds])
type2_mm_ds = ToyMMDataset([type2_ehr_ds, dna_ds, notes_ds])
type3_mm_ds = ToyMMDataset([type3_ehr_ds, notes_ds, mri_ds, ecg_ds])



modtype_ds, modtype_sampler = create_modality_ds_sampler([type1_mm_ds, type2_mm_ds, type3_mm_ds], batch_size=ehr_batch_sz, shuffle=True)
ehr_dl = DataLoader(modtype_ds,  batch_sampler=modtype_sampler)

In [None]:
modtype_ds.cumulative_sizes

[4, 11, 16]

In [None]:
next(iter(ehr_dl))

[[('type2_3', 'type2_5'), tensor([1, 0]), tensor([2, 2])],
 [('dna-type2_3', 'dna-type2_5'), ('notes-type2_3', 'notes-type2_5')]]

#### Tests on Coherent

In [None]:
coherent_splits = EHRDataSplits(COHERENT_DATA_STORE,  age_start=240, age_range=120, start_is_date=False, age_in_months=True)

In [None]:
splits, mod_types = coherent_splits.get_splits_modtypes()

In [None]:
mod_types

{'train': ['0', '21', '30', '31', '11', '1', '20', '10'],
 'valid': ['21', '30', '31', '11', '1', '20', '10'],
 'test': ['21', '30', '31', '11', '1', '20', '10']}

In [None]:
# export


def get_dls(splits: dict, modality_types: dict, labels: list, datastore: str, batch_size: int, num_workers: int):
    """Build one multimodal `DataLoader` per split.

    For every split, each patient list is wrapped in an `EHRDataset` and
    combined with the modality datasets implied by its modality type code:

    - MRI for codes 1, 11, 21, 31
    - DNA for codes 10, 11, 30, 31
    - ECG for codes 20, 21, 30, 31

    The modality datasets are always appended in the order MRI, DNA, ECG.

    Args:
        splits: maps "train"/"valid"/"test" to a list of patient lists.
        modality_types: maps the same split names to the modality type codes
            (convertible with `int`) of the corresponding patient lists.
        labels: labels passed to each `EHRDataset`.
        datastore: root path; MRI is read from `<datastore>/output/dicom`,
            DNA from `<datastore>/output/dna`, ECG from `<datastore>`.
        batch_size: batch size handed to `create_modality_ds_sampler`.
        num_workers: number of `DataLoader` worker processes.

    Returns:
        dict mapping each split name to its `DataLoader`. Only the "train"
        split is shuffled.
    """
    # DataLoader needs an integer worker count; callers may pass e.g. cpu_cnt/2.
    n_workers = int(num_workers)

    dls = {}
    for split in ["train", "valid", "test"]:

        ptlists = splits[split]
        mod_types = modality_types[split]
        assert len(ptlists) == len(mod_types)

        multimodal_ds_list = []
        for ptlist, mod_type in zip(ptlists, mod_types):
            mod_type = int(mod_type)

            # EHR always comes first; MultimodalDataset treats element 0 as EHR.
            unimodal_ds_list = [
                EHRDataset(ptlist=ptlist, labels=labels, modality_type=mod_type)
            ]
            # + MRI
            if mod_type in [1, 11, 21, 31]:
                unimodal_ds_list.append(
                    MRIDataset(f"{datastore}/output/dicom", (4, 4))
                )
            # + DNA
            if mod_type in [10, 11, 30, 31]:
                unimodal_ds_list.append(
                    DNADataset(f"{datastore}/output/dna", (3, 2))
                )
            # + ECG
            if mod_type in [20, 21, 30, 31]:
                unimodal_ds_list.append(ECGDataset(f"{datastore}", (5,)))

            multimodal_ds_list.append(MultimodalDataset(unimodal_ds_list))

        ds, sampler = create_modality_ds_sampler(
            multimodal_ds_list, batch_size=batch_size, shuffle=(split == "train")
        )
        # Honour the caller's worker count instead of a hard-coded 0.
        dls[split] = DataLoader(
            dataset=ds,
            batch_sampler=sampler,
            num_workers=n_workers,
            collate_fn=collate_ehr,
        )

    return dls


In [None]:
for split in ["train", "valid"]:
    print(mod_types[split])

['0', '21', '30', '31', '11', '1', '20', '10']
['21', '30', '31', '11', '1', '20', '10']


In [None]:
# `get_dls` also needs the label list and the datastore root; the worker count must be an int.
dls = get_dls(splits, mod_types, labels, COHERENT_DATA_STORE, batch_size=4, num_workers=cpu_cnt // 2)

In [None]:
dls

{'train': <torch.utils.data.dataloader.DataLoader at 0x7f7e01c84040>,
 'valid': <torch.utils.data.dataloader.DataLoader at 0x7f7e01e4b1c0>,
 'test': <torch.utils.data.dataloader.DataLoader at 0x7f7e01e77a90>}

In [None]:
train_dl = dls["train"]
valid_dl = dls["valid"]
test_dl = dls["test"]

In [None]:
batch = next(iter(train_dl))
batch

((ptid:57c53e54-8f99-3393-8cc3-94e98bfbdc68, birthdate:1928-10-29, [('heart_failure', True), ('coronary_heart', False)].., device:cpu,
  ptid:f12ff989-9593-71f1-66e6-502ec9b3fecd, birthdate:1946-04-26, [('heart_failure', True), ('coronary_heart', False)].., device:cpu,
  ptid:182c006c-b98a-0b2c-39f9-ac12d2ac732b, birthdate:1917-08-14, [('heart_failure', False), ('coronary_heart', True)].., device:cpu,
  ptid:2ace6fc1-10d4-dffd-0ad7-0ff8cef81bff, birthdate:1934-12-15, [('heart_failure', True), ('coronary_heart', False)].., device:cpu),
 tensor([[1., 0., 0., 1., 0.],
         [1., 0., 0., 1., 0.],
         [0., 1., 0., 1., 1.],
         [1., 0., 0., 1., 0.]]),
 (31, 31, 31, 31),
 (tensor([[[1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1]],
  
          [[1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1]],
  
          [[1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1]],
  
  

In [None]:
pts, ys, ms, other = batch

In [None]:
if ms[0] == 31:
    mri, dna, ecg = other

In [None]:
ecg

tensor([[20, 20, 20, 20, 20],
        [20, 20, 20, 20, 20],
        [20, 20, 20, 20, 20],
        [20, 20, 20, 20, 20]])

In [None]:
batch = next(iter(valid_dl))
batch

((ptid:1a82483d-7eb2-d5e0-1e1f-398ba129b18b, birthdate:1936-12-22, [('heart_failure', True), ('coronary_heart', False)].., device:cpu,
  ptid:6dc8bd6b-e2a8-92bf-613d-8b477eb87d7c, birthdate:1911-12-23, [('heart_failure', True), ('coronary_heart', False)].., device:cpu,
  ptid:844a37ff-ce26-6338-fd6a-0bc1e925a702, birthdate:1933-03-15, [('heart_failure', True), ('coronary_heart', False)].., device:cpu),
 tensor([[1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.]]),
 (21, 21, 21),
 (tensor([[[1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1]],
  
          [[1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1]],
  
          [[1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1]]]),
  tensor([[20, 20, 20, 20, 20],
          [20, 20, 20, 20, 20],
          [20, 20, 20, 20, 20]])))

In [None]:
batch = next(iter(test_dl))
batch

((ptid:8f61d438-d822-b7b0-2184-8a0762fc11f3, birthdate:1914-09-18, [('heart_failure', True), ('coronary_heart', False)].., device:cpu,),
 tensor([[1., 0., 0., 0., 0.]]),
 (21,),
 (tensor([[[1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1]]]),
  tensor([[20, 20, 20, 20, 20]])))

## DataLoader - Using Pytorch DataLoader

**A custom collate function is required**: the default collate cannot handle the list of `Patient` objects in `x` and raises the following error:
```
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class '__main__.Patient'>
```

In [None]:
valid_ds[0:4]

([ptid:8d1ba4bb-7250-4295-be1c-5d0d423e55f7, birthdate:1957-02-13 00:00:00, [('diabetes', True), ('stroke', False)].., device:cpu,
  ptid:f1921fc3-fdfc-441d-a928-27c18002fedf, birthdate:1909-12-22 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
  ptid:fc4aa89c-e441-4c0b-841f-3d16ffe1b235, birthdate:1981-04-24 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
  ptid:4e0be087-7a33-4655-a9c0-f00f23178ac1, birthdate:1977-02-03 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu],
 tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.]]),
 tensor([[3],
         [3],
         [3],
         [3]]))

In [None]:
x_tmps,y_tmps, m_tmps = valid_ds[0:4]

In [None]:
x_tmps

[ptid:8d1ba4bb-7250-4295-be1c-5d0d423e55f7, birthdate:1957-02-13 00:00:00, [('diabetes', True), ('stroke', False)].., device:cpu,
 ptid:f1921fc3-fdfc-441d-a928-27c18002fedf, birthdate:1909-12-22 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
 ptid:fc4aa89c-e441-4c0b-841f-3d16ffe1b235, birthdate:1981-04-24 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
 ptid:4e0be087-7a33-4655-a9c0-f00f23178ac1, birthdate:1977-02-03 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu]

In [None]:
y_tmps

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [None]:
m_tmps

tensor([[3],
        [3],
        [3],
        [3]])

**Old collate fns**

**1. removed cuda calls**
```python
def collate(b):
    xs,ys = zip(*b)
    return [x.to_gpu() for x in xs], torch.unsqueeze(torch.tensor(ys), 1).cuda()
```
**2. removed unsqueeze**
```python
def collate(b):
    xs,ys = zip(*b)
    return xs, torch.unsqueeze(torch.tensor(ys), 1)
```

In [None]:
def collate_ehr(b):
    '''Collate `(x, y, m)` samples for `DataLoader`: keep the xs as a tuple, stack ys and ms.'''
    patients, labels, modalities = zip(*b)
    stacked_labels = torch.stack(labels)
    stacked_modalities = torch.stack(modalities)
    return patients, stacked_labels, stacked_modalities

In [None]:
bs = 2

In [None]:
def get_dls(train_ds, valid_ds, bs, collate_fn=collate_ehr, lazy=True) -> 'train_dl, valid_dl':
    """Wrap the train and valid datasets in `DataLoader`s.

    The train loader shuffles and uses batch size `bs`; the valid loader is
    not shuffled and uses `bs*2`. `lazy` is forwarded as `pin_memory` to both.
    """
    train_dl = DataLoader(
        train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn, pin_memory=lazy
    )
    valid_dl = DataLoader(
        valid_ds, batch_size=bs*2, collate_fn=collate_fn, pin_memory=lazy
    )
    return train_dl, valid_dl

In [None]:
train_dl, valid_dl = get_dls(train_ds, valid_ds, bs)

**Tests - `iter()`, `next()` - Next Batch**

In [None]:
it = iter(valid_dl)
first_x, first_y, first_m = next(it)
second_x, second_y, second_m = next(it)

In [None]:
first_x, first_y, first_m

([ptid:8d1ba4bb-7250-4295-be1c-5d0d423e55f7, birthdate:1957-02-13 00:00:00, [('diabetes', True), ('stroke', False)].., device:cpu,
  ptid:f1921fc3-fdfc-441d-a928-27c18002fedf, birthdate:1909-12-22 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
  ptid:fc4aa89c-e441-4c0b-841f-3d16ffe1b235, birthdate:1981-04-24 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
  ptid:4e0be087-7a33-4655-a9c0-f00f23178ac1, birthdate:1977-02-03 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu],
 tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.]]),
 tensor([[3],
         [3],
         [3],
         [3]]))

In [None]:
first_x[3].med_offsts.is_pinned(), first_y.is_pinned()

(True, True)

In [None]:
second_x, second_y, second_m

([ptid:6d048a56-edb8-4f29-891d-7a84d75a8e78, birthdate:1914-09-05 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
  ptid:4fc76a3b-e39e-4091-a6af-3595e0cb607e, birthdate:1948-06-01 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
  ptid:26ca976d-0b5b-4662-af41-535ff670dd5a, birthdate:2014-09-22 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu,
  ptid:59486a8b-389b-4355-9df4-edc62bbd1a11, birthdate:1951-10-11 00:00:00, [('diabetes', False), ('stroke', False)].., device:cpu],
 tensor([[0., 0., 1., 0., 0., 1., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.]]),
 tensor([[3],
         [3],
         [3],
         [3]]))

In [None]:
second_x[0].alg_nums

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [None]:
second_x[0].alg_nums.is_pinned()

True

**Testing full GPU loading (non-Lazy)**

In [None]:
train_ds = EHRDataset(*labeled.train, modality_type=0, lazy_load_gpu=False)
valid_ds = EHRDataset(*labeled.valid, modality_type=2, lazy_load_gpu=False)

In [None]:
xb,yb,mb = train_ds[0:5]
xb,yb,mb

([ptid:0ace3e15-8aa4-41c5-8b90-2408285ebcfe, birthdate:1986-04-02 00:00:00, [('diabetes', False), ('stroke', False)].., device:cuda:0,
  ptid:af1495be-5077-4087-98b1-9ff624c7582c, birthdate:2008-07-17 00:00:00, [('diabetes', False), ('stroke', False)].., device:cuda:0,
  ptid:f23e12d9-2ec6-4006-b041-ea78d374e9c9, birthdate:2014-09-06 00:00:00, [('diabetes', False), ('stroke', False)].., device:cuda:0,
  ptid:1968aa31-5fce-461a-9486-6e385a7b75e7, birthdate:1986-04-11 00:00:00, [('diabetes', False), ('stroke', False)].., device:cuda:0,
  ptid:1211c8ff-ab73-49f3-b2ab-87b7a03f6167, birthdate:1972-03-24 00:00:00, [('diabetes', False), ('stroke', False)].., device:cuda:0],
 tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.]], device='cuda:0'),
 tensor([[0],
         [0],
         [0],
         [0],
         [0]], device='cuda:0'))

In [None]:
xb[0].demographics.is_pinned()

False

In [None]:
train_dl, valid_dl = get_dls(train_ds, valid_ds, bs, lazy=False)

In [None]:
x_tmp, y_tmp, m_tmp = next(iter(valid_dl))

In [None]:
x_tmp[0].demographics.is_pinned(), m_tmp[0]

(False, tensor([2], device='cuda:0'))

In [None]:
x_tmp[0]

ptid:8d1ba4bb-7250-4295-be1c-5d0d423e55f7, birthdate:1957-02-13 00:00:00, [('diabetes', True), ('stroke', False)].., device:cuda:0

In [None]:
# export
class EHRData:
    """All encompassing class for EHR data
    Holds Splits, Labels, Datasets, DataLoaders and
    provides convenience fns for training and prediction."""

    def __init__(
        self,
        path,
        labels,
        age_start,
        age_range,
        start_is_date,
        age_in_months,
        lazy_load_gpu=True,
    ):
        """Store the dataset location, label names and loading options.

        Nothing is loaded here; see `load_splits`, `label`, `create_datasets`
        and `create_dls`, or `get_data` which runs them all.
        """
        self.path, self.labels = path, labels
        self.age_start, self.age_range = age_start, age_range
        self.start_is_date, self.age_in_months = start_is_date, age_in_months
        self.lazy_load_gpu = lazy_load_gpu

    def load_splits(self, modality_type):
        """Load data splits given dataset path"""
        # NOTE(review): `EHRDataSplits.__init__` as defined earlier in this
        # notebook takes (path, age_start, age_range, start_is_date,
        # age_in_months) with no `modality_type`; this positional call looks
        # out of sync with it - confirm which signature is intended.
        self.splits = EHRDataSplits(
            self.path,
            modality_type,
            self.age_start,
            self.age_range,
            self.start_is_date,
            self.age_in_months,
        )

    def label(self):
        """Run labeler - i.e. extract y from patient objects"""
        self.labeled = LabelEHRData(*self.splits.get_splits(), self.labels)

    def create_datasets(self, modality_type):
        """Create `EHRDataset`s for the train, valid and test splits."""
        self.train_ds = EHRDataset(*self.labeled.train, modality_type, self.lazy_load_gpu)
        self.valid_ds = EHRDataset(*self.labeled.valid, modality_type, self.lazy_load_gpu)
        self.test_ds = EHRDataset(*self.labeled.test, modality_type, self.lazy_load_gpu)

    # Deliberately has no `self`: it is referenced below as the default value
    # of `c_fn` and handed to `DataLoader` as a plain function.
    def ehr_collate(b):
        """Custom collate function for use in `DataLoader`"""
        xs, ys, ms = zip(*b)
        return xs, torch.stack(ys), torch.stack(ms)

    def create_dls(self, bs, lazy, c_fn=ehr_collate, **kwargs):
        """Create `DataLoader`s

        The train loader shuffles with batch size `bs`; valid and test use
        `bs * 2` without shuffling. `lazy` is forwarded as `pin_memory` and
        `kwargs` are passed through to every `DataLoader`.
        """
        self.train_dl = DataLoader(
            self.train_ds, bs, shuffle=True, collate_fn=c_fn, pin_memory=lazy, **kwargs
        )
        self.valid_dl = DataLoader(
            self.valid_ds, bs * 2, collate_fn=c_fn, pin_memory=lazy, **kwargs
        )
        self.test_dl = DataLoader(
            self.test_ds, bs * 2, collate_fn=c_fn, pin_memory=lazy, **kwargs
        )

    def _per_modality(self, modality_type, bs=64, num_workers=0):
        """Return all data per modality.

        Runs the full pipeline (splits, labels, datasets, dataloaders) for
        `modality_type`.

        Returns:
            `(train_dl, valid_dl, test_dl, pos_wts)` where `pos_wts` maps
            "train"/"valid"/"test" to a tensor built from the matching column
            of `self.splits.get_pos_wts(self.labels)`.
        """
        self.load_splits(modality_type)
        self.label()
        self.create_datasets(modality_type)
        self.create_dls(bs, self.lazy_load_gpu, num_workers=num_workers)

        pos_wts_df = self.splits.get_pos_wts(self.labels)
        # Read from the frame returned above, not from the dict being built.
        pos_wts = {
            split: torch.Tensor(pos_wts_df[split].values)
            for split in ("train", "valid", "test")
        }
        return self.train_dl, self.valid_dl, self.test_dl, pos_wts

    def get_data(self, bs=64, num_workers=0):
        """Return all data for every modality.

        Each entry of `<path>/processed` is treated as a modality type; the
        result maps that entry name to the tuple returned by `_per_modality`.
        """
        modality_types = os.listdir(f"{self.path}/processed")
        data = {}
        for m in modality_types:
            data[m] = self._per_modality(m, bs, num_workers)

        return data


In [None]:
show_doc(EHRData.load_splits)

<h4 id="EHRData.load_splits" class="doc_header"><code>EHRData.load_splits</code><a href="__main__.py#L22" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRData.load_splits</code>()

Load data splits given dataset path

In [None]:
show_doc(EHRData.label)

<h4 id="EHRData.label" class="doc_header"><code>EHRData.label</code><a href="__main__.py#L14" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRData.label</code>()

Run labeler - i.e. extract y from patient objects

In [None]:
show_doc(EHRData.create_datasets)

<h4 id="EHRData.create_datasets" class="doc_header"><code>EHRData.create_datasets</code><a href="__main__.py#L18" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRData.create_datasets</code>()

Create [`EHRDataset`](/lemonpie/data.html#EHRDataset)s

In [None]:
show_doc(EHRData.ehr_collate)

<h4 id="EHRData.ehr_collate" class="doc_header"><code>EHRData.ehr_collate</code><a href="__main__.py#L24" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRData.ehr_collate</code>(**`b`**)

Custom collate function for use in `DataLoader`

In [None]:
show_doc(EHRData.create_dls)

<h4 id="EHRData.create_dls" class="doc_header"><code>EHRData.create_dls</code><a href="__main__.py#L29" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRData.create_dls</code>(**`bs`**, **`lazy`**, **`c_fn`**=*`ehr_collate`*, **\*\*`kwargs`**)

Create `DataLoader`s

In [None]:
show_doc(EHRData.get_data)

<h4 id="EHRData.get_data" class="doc_header"><code>EHRData.get_data</code><a href="__main__.py#L35" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRData.get_data</code>(**`bs`**=*`64`*, **`num_workers`**=*`0`*)

Convenience function - returns everything needed for training

In [None]:
show_doc(EHRData.get_test_data)

<h4 id="EHRData.get_test_data" class="doc_header"><code>EHRData.get_test_data</code><a href="__main__.py#L47" class="source_link" style="float:right">[source]</a></h4>

> <code>EHRData.get_test_data</code>(**`bs`**=*`64`*, **`num_workers`**=*`0`*)

Convenience function - returns everything needed for prediction using test data

class MultiModalEHRData:
    def __init__(
        self,
        path,
        labels,
        age_start,
        age_range,
        start_is_date,
        age_in_months,
        lazy_load_gpu=True,
    ):


## Export -

In [None]:
#hide
from nbdev.export import *
notebook2script()

Converted 00_basics.ipynb.
Converted 01_preprocessing_clean.ipynb.
Converted 02_preprocessing_vocab.ipynb.
Converted 03_preprocessing_transform.ipynb.
Converted 04_data.ipynb.
Converted 05_metrics.ipynb.
Converted 06_learn.ipynb.
Converted 07_models.ipynb.
Converted 08_experiment.ipynb.
Converted 999_amp_testing.ipynb.
Converted 999_fusion_clean.ipynb.
Converted 999_fusion_models.ipynb.
Converted 99_quick_walkthru.ipynb.
Converted 99_running_exps.ipynb.
Converted index.ipynb.
