In [23]:
import pandas as pd

# Alternative input file, kept for reference:
# eicu = pd.read_csv("../data/one_hot_age_gender_region_eicu_data.csv")
eicu = pd.read_csv("../data/one_hot_age_eicu_data_2.csv")

# Columns from position 8 onward are used as model features; the last 9 of
# them are left out of the "drug" count (in the preview below these are
# Gender plus the eight age buckets).
# NOTE(review): the preview header lists `cardiovascular` as the column at
# position 8, so both slices may include that label column - confirm the
# intended start position.
feature_columns = eicu.columns[8:]
drug_columns = eicu.columns[8:-9]

print(eicu.shape)
print("Total hospitals: ", len(eicu["hospitalid"].unique()))
print("Total drugs: ", len(drug_columns))
print("Total features: ", len(feature_columns))
eicu.head()

(63642, 1416)
Total hospitals:  164
Total drugs:  1399
Total features:  1408


Unnamed: 0.1,Unnamed: 0,patientunitstayid,hospitalid,Death,unitdischargeoffset,ventilation,sepsis,cardiovascular,2 ML - METOCLOPRAMIDE HCL 5 MG/ML IJ SOLN,3 ML VIAL : INSULIN REGULAR HUMAN 100 UNIT/ML IJ SOLN,...,SODIUM CHLORIDE BACTERIOSTATIC 0.9 % INJ SOLN,Gender,< 30,30 - 39,40 - 49,50 - 59,60 - 69,70 - 79,80 - 89,> 89
0,0,141168.0,59.0,1.0,3596.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
1,1,141194.0,73.0,0.0,4813.0,0.0,0.0,0.0,1.0,1.0,...,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
2,2,141233.0,73.0,0.0,15685.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
3,3,141244.0,73.0,0.0,3835.0,0.0,0.0,0.0,1.0,1.0,...,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0
4,4,141265.0,63.0,0.0,6068.0,0.0,0.0,0.0,0.0,0.0,...,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0


In [24]:
import torch

# Task label handed to the service classes below; they compare it against
# COLOR_MNIST to choose between the MSE and cross-entropy loss branches.
task = "death"

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [25]:
import numpy as np
from sklearn.model_selection import train_test_split

# Feature matrix: every column from position 8 onward, as float32 on `device`.
# NOTE(review): per the frame preview above, position 8 appears to be the
# `cardiovascular` label - confirm it is meant to be part of the features.
feature_frame = eicu.iloc[:, 8:]
feature_array = np.asarray(feature_frame).astype(np.float32)
x = torch.Tensor(feature_array).to(device)
print(x.shape)

# 70/30 train/test split.
# NOTE(review): no random_state is passed, so the split differs between runs.
x_train, x_test = train_test_split(x, test_size=0.3)
for split_part in (x_train, x_test):
    print(split_part.shape)

torch.Size([63642, 1408])
torch.Size([44549, 1408])
torch.Size([19093, 1408])


In [26]:
from sklearn.metrics import roc_auc_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import auc
import torch.nn as nn


class EvaluationUtils:
    """Stateless helpers wrapping common loss functions and evaluation metrics.

    The three loss helpers forward ``reduction`` to the corresponding
    ``torch.nn`` criterion; the metric helpers operate on tensors and, for the
    AUC metrics, move them to NumPy before calling scikit-learn.
    """

    @staticmethod
    def mean_bce(pred, y, reduction='mean'):
        """Binary cross entropy between probabilities ``pred`` and targets ``y``."""
        return nn.BCELoss(reduction=reduction)(pred, y)

    @staticmethod
    def mean_ce(pred, y, reduction='mean'):
        """Cross entropy between logits ``pred`` and targets ``y``."""
        return nn.CrossEntropyLoss(reduction=reduction)(pred, y)

    @staticmethod
    def mean_mse(pred, y, reduction='mean'):
        """Squared error between ``pred`` and ``y``."""
        return nn.MSELoss(reduction=reduction)(pred, y)

    @staticmethod
    def mean_accuracy(pred, y):
        """Fraction of entries whose prediction, thresholded at 0.5, matches ``y``."""
        hard_pred = (pred > 0.5).float()
        is_match = (hard_pred - y).abs() < 1e-2
        return is_match.float().mean()

    @staticmethod
    def mean_roc_auc(pred, y):
        """Area under the ROC curve, computed by ``roc_auc_score``."""
        y_np = y.detach().cpu().numpy()
        pred_np = pred.detach().cpu().numpy()
        return roc_auc_score(y_np, pred_np)

    @staticmethod
    def mean_pr_auc(pred, y):
        """Area under the precision-recall curve (``auc`` over recall/precision)."""
        y_np = y.detach().cpu().numpy()
        pred_np = pred.detach().cpu().numpy()
        precision, recall, _ = precision_recall_curve(y_np, pred_np)
        return auc(recall, precision)

In [28]:
# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
DATASET_DIR_CONFIG_KEY = "dataset_dir_path"
DATASET_DIR_DEFAULT = "../data/"
EICU_PATH = "one_hot_age_gender_eicu_data.csv"

# ---------------------------------------------------------------------------
# FL
# ---------------------------------------------------------------------------
LR = 1e-4
WEIGHT_DECAY = 1e-3
MIN_HOSPITAL_DEATH_COUNT = 150
TOTAL_FL_ROUND = 1501
LOCAL_FL_EPOCHS = 1
TEST_SIZE = 0.3
VAL_SIZE = 0.2
# NOTE(review): the frame loaded at the top of this notebook yields 1408
# feature columns, not 1420 - confirm which dataset this value refers to.
TOTAL_FEATURE = 1420
FL_HIDDEN_LAYER_UNITS = "256,128,64"
BIAS_INIT_PRIOR_PROB = None
BATCH_SIZE = 128

# ---------------------------------------------------------------------------
# MADE
# ---------------------------------------------------------------------------
MADE_EPOCHS = 50
MADE_HIDDEN_LAYER_UNITS = "500"
MADE_NUM_MASKS = 1
MADE_SAMPLES = 1
MADE_RESAMPLE_EVERY = 20

# ---------------------------------------------------------------------------
# Weight
# ---------------------------------------------------------------------------
REWEIGHT_LAMBDA = 1.0

# ---------------------------------------------------------------------------
# Task identifiers
# ---------------------------------------------------------------------------
# Hospital-level tasks
VENTILATOR = "ventilator"          # ventilator usage
SEPSIS = "sepsis"                  # sepsis diagnosis
DEATH = "death"                    # death prediction
LENGTH = "length"                  # length of stay
CARDIOVASCULAR = "cardiovascular"  # cardiovascular
# Region-level counterparts
REGION_VENTILATOR = "region_ventilator"
REGION_SEPSIS = "region_sepsis"
REGION_DEATH = "region_death"
REGION_LENGTH = "region_length"
REGION_TASK_PREFIX = "region"
# Other datasets
SIMULATION = "simulation"                # simulation dataset
SIMULATION_BY_DIR = "simulation_by_dir"  # simulation dataset loaded by directory
COLOR_MNIST = "color_mnist"              # Color MNIST
BINARIZED_MNIST = "binarized_mnist"      # Binarized MNIST

# ---------------------------------------------------------------------------
# Algorithm
# ---------------------------------------------------------------------------
WEIGHTED = "weighted"
UNWEIGHTED = "unweighted"
BOTH = "both"  # both variants, used by paired

# ---------------------------------------------------------------------------
# Stage
# ---------------------------------------------------------------------------
GRID_SEARCH = "grid_search"
RETRAIN = "retrain"

# ---------------------------------------------------------------------------
# Label index per task
# NOTE(review): presumably the position of each task's label column in the
# dataset - verify against the loader that consumes this mapping.
# ---------------------------------------------------------------------------
LABEL_IDX = {
    VENTILATOR: 5,
    REGION_VENTILATOR: 5,
    SEPSIS: 6,
    REGION_SEPSIS: 6,
    CARDIOVASCULAR: 7,
    DEATH: 3,
    REGION_DEATH: 3,
    LENGTH: 4,
    REGION_LENGTH: 4,
    COLOR_MNIST: 0,
}

# ---------------------------------------------------------------------------
# Log
# ---------------------------------------------------------------------------
LOG_PATH_CONFIG_KEY = "log_dir_path"
LOG_PATH_DEFAULT = "../log/"
LOGGER_DEFAULT = "logger_default"
LOGGER_MADE = "logger_made"

# ---------------------------------------------------------------------------
# Output
# ---------------------------------------------------------------------------
OUTPUT_PATH_CONFIG_KEY = "output_dir_path"
OUTPUT_PATH_DEFAULT = "../output/"

# ---------------------------------------------------------------------------
# Simulate
# ---------------------------------------------------------------------------
SIMULATE_SOURCE_HOSPITAL_ID = 420
SIMULATE_TARGET_HOSPITAL_ID = 449

SIMULATE_X_SOURCE_PATH = "simulate_x_source.csv"
SIMULATE_Y_SOURCE_PATH = "simulate_y_source.csv"
SIMULATE_X_TARGET_PATH = "simulate_x_target.csv"
SIMULATE_Y_TARGET_PATH = "simulate_y_target.csv"

SIMULATE_DATA_DIR = "simulation/"

TARGET_HOSPITAL_ID = "Northeast"
TOTAL_SEED = 10

FED_WEIGHT_METHOD_SGD = "fed_weight_method_sgd"
FED_WEIGHT_METHOD_AVG = "fed_weight_method_avg"

# Auto-encoder

In [42]:
class Autoencoder(nn.Module):
    """Fully connected autoencoder whose decoder mirrors its encoder.

    Args:
        in_features: width of the input (and of the reconstruction).
        hidden_sizes: comma-separated encoder layer widths, e.g. ``"1024,512,256"``.
            The decoder uses the same widths in reverse order.
    """

    def __init__(self, in_features: int,
                 hidden_sizes: str):
        super().__init__()
        widths = [in_features] + [int(token) for token in hidden_sizes.split(',')]

        # Encoder: Linear + ReLU for every consecutive pair of widths.
        encoder_layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            encoder_layers += [nn.Linear(n_in, n_out), nn.ReLU()]
        self._encoder = nn.Sequential(*encoder_layers)

        # Decoder: same pattern over the reversed widths, but the final
        # Linear layer is left without an activation.
        mirrored = widths[::-1]
        decoder_layers = []
        for n_in, n_out in zip(mirrored[:-1], mirrored[1:]):
            decoder_layers += [nn.Linear(n_in, n_out), nn.ReLU()]
        decoder_layers.pop()
        self._decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and return its reconstruction (no output activation)."""
        return self._decoder(self._encoder(x))

In [43]:
import torch
from torch.autograd import Variable
import math

class AutoEncoderService:
    """Runs one full pass of an autoencoder over a data split.

    The task name given at construction selects the reconstruction loss:
    mean squared error when it equals ``COLOR_MNIST``, cross entropy (with the
    input row used as the target) otherwise.
    """

    def __init__(self, task: str) -> None:
        self._task = task

    def run_auto_encoder(self,
                         model, opt,
                         x_train, x_test,
                         split, batch_size, device):
        """Run the model over one split, updating weights when ``split == 'train'``.

        Args:
            model: module mapping a batch to its reconstruction.
            opt: optimizer over ``model``'s parameters (only stepped for 'train').
            x_train: 2-D tensor used when ``split == 'train'``.
            x_test: 2-D tensor used for any other ``split`` value.
            split: ``'train'`` enables gradients and training mode; anything
                else runs in eval mode without gradients.
            batch_size: number of rows per batch; the last batch may be smaller.
            device: device on which the per-sample loss tensor is allocated.

        Returns:
            ``(mean_batch_loss, loss_for_samples)`` where ``mean_batch_loss`` is
            the average of the per-batch losses and ``loss_for_samples`` holds
            one detached loss value per row of the split.

        Raises:
            ValueError: if ``batch_size`` is not in ``1..len(x)``.
        """
        is_train = split == 'train'
        x = x_train if is_train else x_test

        if batch_size <= 0 or batch_size > len(x):
            raise ValueError(
                "Batch size must be larger than 0 and smaller than sample size")

        N = x.size(0)
        B = batch_size  # previously hard-coded to 64, ignoring the argument
        nsteps = math.ceil(N / B)

        loss_for_samples = torch.full((N,), torch.nan, device=device)  # N
        loss_total = []

        if is_train:
            model.train()
        else:
            model.eval()

        # Scope the grad mode to this call so an evaluation pass does not
        # leave autograd disabled for whatever runs next in the notebook.
        with torch.set_grad_enabled(is_train):
            for step in range(nsteps):
                # ceil(N / B) steps already cover the final, shorter batch
                xb = x[step * B: step * B + B]
                self._run_batch(step, B, N, model, opt, xb, split,
                                loss_total, loss_for_samples)

        # Every row must have received a loss; NaN here means a batch was
        # skipped or the loss itself diverged.
        assert not torch.isnan(loss_for_samples).any()
        return sum(loss_total) / len(loss_total), loss_for_samples

    def _run_batch(self, step, B, N,
                   model, opt,
                   xb, split,
                   loss_total,
                   loss_for_samples):
        """Forward one batch, record its losses and (for 'train') update weights.

        ``step`` and ``B`` locate the batch inside ``loss_for_samples``; ``N`` is
        unused and kept only so the signature stays unchanged.
        """
        pred = model(xb)

        if self._task == COLOR_MNIST:
            # Gaussian
            loss_each = EvaluationUtils.mean_mse(pred, xb, reduction='none')  # batch_size x D
            loss_sample = torch.mean(loss_each, dim=1)  # batch_size
        else:
            # Multinomial
            loss_sample = EvaluationUtils.mean_ce(pred, xb, reduction='none')  # batch_size

        # Averaging the per-sample values gives the batch loss that the
        # 'mean' reduction produced before, while also making the per-sample
        # values available for bookkeeping.
        loss = torch.mean(loss_sample)
        loss_total.append(loss.item())

        # The per-sample buffer was previously never filled, which made the
        # NaN assertion in run_auto_encoder fail on every call.
        start = step * B
        loss_for_samples[start: start + B] = loss_sample.detach()

        # backward/update
        if split == 'train':
            opt.zero_grad()
            loss.backward()
            opt.step()

In [44]:
# Auto-encoder hyper-parameters
ae_hiddens = "1024,512,256"   # comma-separated encoder layer widths
ae_epochs = 150
ae_batch_size = 64
ae_learning_rate = 1e-4
ae_weight_decay = 1e-3

In [45]:
import torch

# Build the autoencoder on the same device as the data. `x` was moved to
# `device` when it was created, so leaving the model on the default device
# would fail as soon as CUDA is in use.
ae = Autoencoder(x.size(1), ae_hiddens).to(device)
ae_service = AutoEncoderService(task)
ae_opt = torch.optim.Adam(ae.parameters(),
                          lr=ae_learning_rate,
                          weight_decay=ae_weight_decay)
# Multiply the learning rate by 0.1 every 45 epochs.
ae_scheduler = torch.optim.lr_scheduler.StepLR(
    ae_opt, step_size=45, gamma=0.1)

ae_train_hist = []

for epoch in range(ae_epochs):
    train_loss, _ = ae_service.run_auto_encoder(ae, ae_opt,
                                                x_train, None,
                                                'train', ae_batch_size,
                                                device)

    ae_scheduler.step()
    ae_train_hist.append(train_loss)

    print("epoch: {}, train loss: {}".format(epoch, train_loss))

# Single evaluation pass over the held-out split.
test_loss, _ = ae_service.run_auto_encoder(ae, ae_opt,
                                           None, x_test,
                                           'test', ae_batch_size,
                                           device)

KeyboardInterrupt: 

# Multinomial MADE

In [10]:
# MADE hyper-parameters
# -- architecture
made_hiddens = "1408"        # comma-separated hidden layer widths
num_masks = 1
natural_ordering = False
# -- optimisation
made_learning_rate = 1e-4
made_weight_decay = 1e-3
made_epochs = 50
made_batch_size = 64
# -- mask sampling
made_samples = 10            # forward passes averaged per batch outside training
resample_every = 20          # training steps between mask updates

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class MaskedLinear(nn.Linear):
    """Linear layer whose weight matrix is multiplied element-wise by a mask.

    The mask is a buffer of shape ``(out_features, in_features)`` initialised
    to all ones, so a fresh layer behaves like a plain ``nn.Linear``.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)
        all_connected = torch.ones(out_features, in_features)
        self.register_buffer('mask', all_connected)

    def set_mask(self, mask):
        """Overwrite the mask from a NumPy array of shape ``(in_features, out_features)``."""
        # The incoming array is laid out (in, out); the buffer is (out, in).
        transposed = mask.astype(np.uint8).T
        self.mask.data.copy_(torch.from_numpy(transposed))

    def forward(self, input):
        """Apply the linear map using only the connections enabled by the mask."""
        masked_weight = self.mask * self.weight
        return F.linear(input, masked_weight, self.bias)


class MADE(nn.Module):
    """Masked autoencoder (MADE) built from ``MaskedLinear`` layers.

    The layer masks are derived from an ordering of the inputs and a random
    degree assigned to every hidden unit; ``update_masks`` rebuilds them from
    a seed that cycles through ``num_masks`` values.
    """

    def __init__(self, nin, hidden_sizes, nout, num_masks=1, natural_ordering=False):
        """
        nin: integer; number of inputs
        hidden sizes: a list of integers; number of units in hidden layers
        nout: integer; number of outputs, which usually collectively parameterize some kind of 1D distribution
              note: if nout is e.g. 2x larger than nin (perhaps the mean and std), then the first nin
              will be all the means and the second nin will be stds. i.e. output dimensions depend on the
              same input dimensions in "chunks" and should be carefully decoded downstream appropriately.
        num_masks: can be used to train ensemble over orderings/connections
        natural_ordering: force natural ordering of dimensions, don't use random permutations
        """

        super().__init__()
        self.nin = nin
        self.nout = nout
        self.hidden_sizes = hidden_sizes
        assert self.nout % self.nin == 0, "nout must be integer multiple of nin"

        # define a simple MLP neural net: MaskedLinear + ReLU per layer,
        # with no activation after the output layer
        layers = []
        hs = [nin] + hidden_sizes + [nout]
        for h0, h1 in zip(hs, hs[1:]):
            layers.extend([
                MaskedLinear(h0, h1),
                nn.ReLU(),
            ])
        layers.pop()  # pop the last ReLU for the output layer
        self.net = nn.Sequential(*layers)

        # seeds for orders/connectivities of the model ensemble
        self.natural_ordering = natural_ordering
        self.num_masks = num_masks
        self.seed = 0  # for cycling through num_masks orderings

        self.m = {}
        self.update_masks()  # builds the initial self.m connectivity
        # note, we could also precompute the masks and cache them, but this
        # could get memory expensive for large number of masks.

    def update_masks(self):
        """Rebuild all layer masks from the next seed in the cycle.

        With ``num_masks == 1`` the seed is always 0, so once the masks exist a
        rebuild would reproduce exactly the same matrices; that case returns
        immediately instead of regenerating and re-copying them.
        """
        if self.m and self.num_masks == 1:
            return

        L = len(self.hidden_sizes)

        # fetch the next seed and construct a random stream
        rng = np.random.RandomState(self.seed)
        self.seed = (self.seed + 1) % self.num_masks

        # sample the order of the inputs and the connectivity of all neurons
        self.m[-1] = np.arange(
            self.nin) if self.natural_ordering else rng.permutation(self.nin)
        for l in range(L):
            self.m[l] = rng.randint(
                self.m[l - 1].min(), self.nin - 1, size=self.hidden_sizes[l])

        # construct the mask matrices: hidden layers use <=, the output layer
        # uses a strict < against the input ordering
        masks = [self.m[l - 1][:, None] <= self.m[l][None, :]
                 for l in range(L)]
        masks.append(self.m[L - 1][:, None] < self.m[-1][None, :])

        # handle the case where nout = nin * k, for integer k > 1
        if self.nout > self.nin:
            k = int(self.nout / self.nin)
            # replicate the mask across the other outputs
            masks[-1] = np.concatenate([masks[-1]] * k, axis=1)

        # set the masks in all MaskedLinear layers
        layers = [l for l in self.net.modules() if isinstance(l, MaskedLinear)]
        for l, m in zip(layers, masks):
            l.set_mask(m)

    def forward(self, x):
        """Return the raw network outputs for ``x`` under the current masks."""
        return self.net(x)

In [17]:
import torch
from torch.autograd import Variable
import math


class MadeService:
    """Runs one full pass of a MADE model over a data split.

    The task name given at construction selects the loss: mean squared error
    when it equals ``COLOR_MNIST``, otherwise cross entropy against the input
    row divided by that row's sum.
    """

    def __init__(self, task: str) -> None:
        self._task = task

    def run_made(self,
                 model, opt,
                 x_train, x_test,
                 split, batch_size,
                 samples, resample_every,
                 device):
        """Run the model over one split, updating weights when ``split == 'train'``.

        Args:
            model: module exposing ``update_masks()`` and mapping a batch to
                outputs of the same shape.
            opt: optimizer over ``model``'s parameters (only stepped for 'train').
            x_train: 2-D tensor used when ``split == 'train'``.
            x_test: 2-D tensor used for any other ``split`` value.
            split: ``'train'`` enables gradients and training mode.
            batch_size: rows per batch; the last batch may be smaller.
            samples: forward passes averaged per batch outside training
                (training always uses one).
            resample_every: masks are updated on steps divisible by this value.
            device: device on which the per-sample loss tensor is allocated.

        Returns:
            ``(mean_batch_loss, loss_for_samples)``; the second item holds one
            detached loss value per row of the split.

        Raises:
            ValueError: if ``batch_size`` is not in ``1..len(x)``.
        """
        is_train = split == 'train'
        nsamples = 1 if is_train else samples
        x = x_train if is_train else x_test

        if batch_size <= 0 or batch_size > len(x):
            raise ValueError(
                "Batch size must be larger than 0 and smaller than sample size")

        N = x.size(0)
        B = batch_size  # previously hard-coded to 64, ignoring the argument
        nsteps = math.ceil(N / B)

        loss_for_samples = torch.full((N,), torch.nan, device=device)  # N
        loss_total = []

        if is_train:
            model.train()
        else:
            model.eval()

        # enable/disable grad for efficiency of forwarding test batches; the
        # context manager restores the previous mode when the pass is done
        with torch.set_grad_enabled(is_train):
            for step in range(nsteps):
                # ceil(N / B) steps already cover the final, shorter batch
                xb = x[step * B: step * B + B]
                self._run_batch(step, B, N, model, opt, xb, split,
                                nsamples, resample_every, loss_total, loss_for_samples)

        assert not torch.isnan(loss_for_samples).any()
        return sum(loss_total) / len(loss_total), loss_for_samples

    def _run_batch(self, step, B, N,
                   model, opt,
                   xb, split,
                   nsamples, resample_every,
                   loss_total, loss_for_samples):
        """Forward one batch, record its losses and (for 'train') update weights.

        ``step`` and ``B`` locate the batch inside ``loss_for_samples``; ``N`` is
        unused and kept only so the signature stays unchanged.
        """
        # get the logits, potentially run the same batch a number of times,
        # resampling the masks each time
        xbhat = torch.zeros_like(xb)
        for _ in range(nsamples):
            # perform order/connectivity-agnostic training by resampling the masks
            if step % resample_every == 0 or split == 'test':  # if in test, cycle masks every time
                model.update_masks()
            xbhat = xbhat + model(xb)
        xbhat = xbhat / nsamples

        if self._task == COLOR_MNIST:
            # Gaussian
            loss_each = EvaluationUtils.mean_mse(xbhat, xb,
                                                 reduction='none')  # batch_size x D
            loss_sample = torch.mean(loss_each, dim=1)  # batch_size
            loss_mean = EvaluationUtils.mean_mse(xbhat, xb)  # scalar
        else:
            # Multinomial (a sigmoid + BCE variant was used here previously)
            loss_sample = EvaluationUtils.mean_ce(xbhat, xb,
                                                  reduction='none')  # batch_size
            # NOTE(review): a row that sums to zero would produce NaN here and
            # trip the assertion in run_made - confirm inputs never contain one.
            num_drugs_taken = torch.sum(xb, dim=1)  # batch_size
            loss_sample = loss_sample / num_drugs_taken  # batch_size
            loss_mean = torch.mean(loss_sample)  # scalar

        loss_total.append(loss_mean.item())

        # Store detached values so the bookkeeping tensor does not keep every
        # batch's autograd graph alive.
        start = step * B
        loss_for_samples[start: start + B] = loss_sample.detach()

        # backward/update
        if split == 'train':
            opt.zero_grad()
            loss_mean.backward()
            opt.step()

In [18]:
import torch

hidden_list = list(map(int, made_hiddens.split(',')))
# Build MADE on the same device as the data. `x` was moved to `device` when
# it was created, so the model's parameters and mask buffers must live there
# too; otherwise the first forward pass fails when CUDA is in use.
made = MADE(x.size(1), hidden_list,
            x.size(1), num_masks=num_masks,
            natural_ordering=natural_ordering).to(device)
made_service = MadeService(task)
made_opt = torch.optim.Adam(made.parameters(),
                            lr=made_learning_rate,
                            weight_decay=made_weight_decay)
# Multiply the learning rate by 0.1 every 45 epochs.
made_scheduler = torch.optim.lr_scheduler.StepLR(
    made_opt, step_size=45, gamma=0.1)

made_train_hist = []

for epoch in range(made_epochs):
    made_train_loss, _ = made_service.run_made(made, made_opt,
                                               x_train, None,
                                               'train', made_batch_size,
                                               made_samples, resample_every, device)

    made_scheduler.step()
    made_train_hist.append(made_train_loss)

# Single evaluation pass over the held-out split.
test_loss, _ = made_service.run_made(made, made_opt,
                                     None, x_test,
                                     'test', made_batch_size,
                                     made_samples, resample_every, device)

KeyboardInterrupt: 

# VAE

In [None]:
# VAE hyper-parameters.
# These were written as `name: value`, which Python parses as a bare variable
# annotation and which does not bind the name; they must be assignments for
# the cells below to find them.
vae_epochs = 150
vae_latent_dim = 8
vae_hiddens = "1024,512,256"   # comma-separated encoder layer widths
vae_learning_rate = 0.0001
vae_weight_decay = 0.001
vae_batch_size = 64

In [None]:
import torch
import torch.nn as nn
from typing import Tuple


class VAE(nn.Module):
    """Variational autoencoder with a fully connected encoder and a small decoder.

    Args:
        in_features: width of the input and of the reconstruction.
        hidden_sizes: comma-separated encoder layer widths, e.g. ``"1024,512,256"``.
        latent_dim: size of the latent vector (and of ``mu`` / ``logvar``).
    """

    def __init__(self, in_features: int,
                 hidden_sizes: str,
                 latent_dim: int) -> None:

        super(VAE, self).__init__()

        hidden_list = list(map(int, hidden_sizes.split(',')))

        # Encoder layers: Linear + ReLU for every consecutive pair of widths
        encoder_layers = []
        encoder_hiddens = [in_features] + hidden_list
        for h0, h1 in zip(encoder_hiddens, encoder_hiddens[1:]):
            encoder_layers.extend([
                nn.Linear(h0, h1),
                nn.ReLU(),
            ])
        self._encoder = nn.Sequential(*encoder_layers)

        # Latent space layers
        self._mu_layer = nn.Linear(encoder_hiddens[-1], latent_dim)
        self._logvar_layer = nn.Linear(encoder_hiddens[-1], latent_dim)

        # Decoder layers: one 64-unit hidden layer, then a projection back to
        # the input width. The output width used to be hard-coded to 1408,
        # which only matched inputs of exactly that size.
        decoder_layers = []
        decoder_hiddens = [latent_dim, 64, in_features]

        for h0, h1 in zip(decoder_hiddens, decoder_hiddens[1:]):
            decoder_layers.extend([
                nn.Linear(h0, h1),
                nn.Dropout(0.3),
                nn.ReLU(),
            ])
        decoder_layers.pop()  # pop the last ReLU for the output layer
        decoder_layers.pop()  # pop the last Dropout for the output layer
        self._decoder = nn.Sequential(*decoder_layers)

        print("Encoder: {}".format(self._encoder))
        print("Mu layer: {}".format(self._mu_layer))
        print("Log var layer: {}".format(self._logvar_layer))
        print("Decoder: {}".format(self._decoder))

    # ---- Public method ----

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(reconstruction, mu, logvar)`` for the batch ``x``."""
        mu, logvar = self._encode(x)
        z = self._reparameterize(mu, logvar)
        reconstructed = self._decode(z)
        return reconstructed, mu, logvar

    # ---- Private methods ----

    def _encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map ``x`` to the mean and log-variance of the latent distribution."""
        hidden = self._encoder(x)
        mu = self._mu_layer(hidden)
        logvar = self._logvar_layer(hidden)
        return mu, logvar

    def _reparameterize(self, mu: torch.Tensor,
                        logvar: torch.Tensor) -> torch.Tensor:
        """Draw ``z = mu + eps * std`` with ``eps`` sampled from a standard normal."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z

    def _decode(self, z: torch.Tensor) -> torch.Tensor:
        """Map a latent vector back to input space (no output activation)."""
        reconstructed = self._decoder(z)
        return reconstructed

In [None]:
import torch
from torch.autograd import Variable
import math


class VaeService:
    """Runs one full pass of a VAE over a data split.

    The task name given at construction selects the reconstruction term: mean
    squared error when it equals ``COLOR_MNIST``, otherwise cross entropy
    against the input row divided by that row's sum.
    """

    def __init__(self, task: str) -> None:
        self._task = task

    def run_vae(self,
                model, opt,
                x_train, x_test,
                split, batch_size,
                beta, device):
        """Run the model over one split, updating weights when ``split == 'train'``.

        Args:
            model: module returning ``(reconstruction, mu, logvar)`` for a batch.
            opt: optimizer over ``model``'s parameters (only stepped for 'train').
            x_train: 2-D tensor used when ``split == 'train'``.
            x_test: 2-D tensor used for any other ``split`` value.
            split: ``'train'`` enables gradients and training mode.
            batch_size: rows per batch; the last batch may be smaller.
            beta: weight of the KL term in the optimised loss.
            device: device on which the per-sample loss tensor is allocated.

        Returns:
            ``(mean_batch_loss, loss_for_samples)``. Each batch loss is a sum
            over the batch; ``loss_for_samples`` holds, per row, the detached
            reconstruction term plus the unweighted KL term.

        Raises:
            ValueError: if ``batch_size`` is not in ``1..len(x)``.
        """
        is_train = split == 'train'
        x = x_train if is_train else x_test

        if batch_size <= 0 or batch_size > len(x):
            raise ValueError(
                "Batch size must be larger than 0 and smaller than sample size")

        N = x.size(0)
        B = batch_size  # previously hard-coded to 64, ignoring the argument
        nsteps = math.ceil(N / B)

        loss_for_samples = torch.full((N,), torch.nan, device=device)  # N
        loss_total = []

        if is_train:
            model.train()
        else:
            model.eval()

        # Scope the grad mode to this call so an evaluation pass does not
        # leave autograd disabled for whatever runs next in the notebook.
        with torch.set_grad_enabled(is_train):
            for step in range(nsteps):
                # ceil(N / B) steps already cover the final, shorter batch
                xb = x[step * B: step * B + B]
                self._run_batch(step, B, N, model, opt, xb, split, beta,
                                loss_total, loss_for_samples)

        assert not torch.isnan(loss_for_samples).any()
        return sum(loss_total) / len(loss_total), loss_for_samples

    def _run_batch(self, step, B, N,
                   model, opt,
                   xb, split, beta,
                   loss_total, loss_for_samples):
        """Forward one batch, record its losses and (for 'train') update weights.

        ``step`` and ``B`` locate the batch inside ``loss_for_samples``; ``N`` is
        unused and kept only so the signature stays unchanged.
        """
        pred, mu, logvar = model(xb)

        if self._task == COLOR_MNIST:
            # Gaussian
            loss_each = EvaluationUtils.mean_mse(pred, xb,
                                                 reduction='none')  # batch_size x D
            recon_sample = torch.mean(loss_each, dim=1)  # batch_size
        else:
            # Multinomial
            recon_sample = EvaluationUtils.mean_ce(pred, xb,
                                                   reduction='none')  # batch_size
            # NOTE(review): a row that sums to zero would produce NaN here and
            # trip the assertion in run_vae - confirm inputs never contain one.
            num_drugs_taken = torch.sum(xb, dim=1)  # batch_size
            recon_sample = recon_sample / num_drugs_taken  # batch_size

        # KL divergence between N(mu, exp(logvar)) and a standard normal,
        # summed over the latent dimensions (identical for both branches).
        kl_sample = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(),
                                     dim=1)  # batch_size

        # Optimised objective: beta-weighted, summed over the batch.
        loss_vae = torch.sum(recon_sample + beta * kl_sample)  # scalar

        # Reported per-sample value uses the unweighted KL term.
        loss_sample = recon_sample + kl_sample  # batch_size

        loss_total.append(loss_vae.item())

        # Store detached values so the bookkeeping tensor does not keep every
        # batch's autograd graph alive.
        start = step * B
        loss_for_samples[start: start + B] = loss_sample.detach()

        # backward/update
        if split == 'train':
            opt.zero_grad()
            loss_vae.backward()
            opt.step()

In [None]:
# Train VAE: build the model on `device`, then optimise with Adam under a
# step learning-rate schedule (x0.1 every 45 epochs).
vae = VAE(x.size(1), vae_hiddens, vae_latent_dim)
vae.to(device)

vae_opt = torch.optim.Adam(
    vae.parameters(),
    lr=vae_learning_rate,
    weight_decay=vae_weight_decay,
)
vae_scheduler = torch.optim.lr_scheduler.StepLR(vae_opt, step_size=45, gamma=0.1)
vae_service = VaeService(task)

vae_train_hist = []
for epoch in range(vae_epochs):
    # KL weight grows linearly by 0.002 per epoch and is capped at 1.0.
    beta = min(0.002 * epoch, 1.0)

    vae_train_loss, _ = vae_service.run_vae(
        vae, vae_opt,
        x_train, None,
        'train', vae_batch_size,
        beta, device)

    vae_scheduler.step()
    vae_train_hist.append(vae_train_loss)

# Evaluate on the held-out split using the KL weight from the last epoch.
vae_test_loss, _ = vae_service.run_vae(
    vae, vae_opt,
    None, x_test,
    'test', vae_batch_size,
    beta, device)

# VQ-VAE

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class VectorQuantizer(nn.Module):
    """Vector-quantisation layer with an exponential-moving-average codebook.

    Each input row is replaced by its nearest codebook row (squared Euclidean
    distance). In training mode the codebook is refreshed from running
    averages of the assigned inputs rather than by gradient descent.

    Args:
        num_embeddings: number of codebook entries (K).
        embedding_dim: width of each codebook entry (D).
        beta: commitment loss coefficient.
        device: stored on the instance for compatibility; tensors created in
            ``forward`` follow the device of the incoming batch instead.
    """

    def __init__(self, num_embeddings: int,
                 embedding_dim: int,
                 beta: float,
                 device: torch.device) -> None:

        super(VectorQuantizer, self).__init__()

        self._num_embeddings = num_embeddings  # K
        self._embedding_dim = embedding_dim  # D
        self._beta = beta  # Commitment loss coefficient
        self._decay = 0.99  # EMA decay factor
        self._epsilon = 1e-5  # Laplace smoothing constant
        self.register_buffer('_ema_cluster_size',
                             torch.zeros(num_embeddings))  # K
        self._ema_w = nn.Parameter(torch.Tensor(
            num_embeddings, self._embedding_dim))  # K x D
        self._ema_w.data.normal_()
        self._device = device

        # Codebook, initialised uniformly in [-1/K, 1/K]
        self._embedding = nn.Embedding(self._num_embeddings,  # K x D
                                       self._embedding_dim)
        self._embedding.weight.data.uniform_(-1 / self._num_embeddings,
                                             1 / self._num_embeddings)

    def forward(self, inputs: torch.Tensor):
        """Quantise ``inputs`` (N x D) and return ``(loss, codeword)``.

        ``loss`` is ``beta`` times the commitment loss; ``codeword`` (N x D)
        carries the quantised values in the forward pass and passes gradients
        straight through to ``inputs`` in the backward pass.
        """
        # Squared Euclidean distance between every input and every codeword
        distances = (torch.sum(inputs**2, dim=1, keepdim=True)  # N x K
                     + torch.sum(self._embedding.weight**2, dim=1)
                     - 2 * torch.matmul(inputs, self._embedding.weight.t()))

        # Latent representation z = q(z|x) e.g. [0 0 1 0 0]
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)  # N x 1
        # Allocate on the batch's own device: the device passed to __init__
        # can differ from where the module and its inputs actually live.
        latent = torch.zeros(encoding_indices.shape[0],  # N x K
                             self._num_embeddings,
                             device=inputs.device)
        latent.scatter_(1, encoding_indices, 1)

        # Codeword
        codeword = torch.matmul(latent, self._embedding.weight)  # N x D

        # EMA
        if self.training:
            # torch.sum(latent, 0): observed cluster size per codeword (K)
            # self._ema_cluster_size: smoothed cluster size per codeword
            self._ema_cluster_size = self._ema_cluster_size * self._decay + \
                (1 - self._decay) * torch.sum(latent, 0)  # K

            # Laplace smoothing of the cluster size
            n = torch.sum(self._ema_cluster_size.data)
            self._ema_cluster_size = (
                (self._ema_cluster_size + self._epsilon)
                / (n + self._num_embeddings * self._epsilon) * n)

            # Row k of latent.T marks which samples were assigned to codeword k,
            # so dw[k] is the sum of encoder outputs assigned to codeword k and
            # self._ema_w is the smoothed version of that sum.
            dw = latent.T @ inputs  # K x D
            # NOTE(review): these assignments create new Parameter objects on
            # every training step, so an optimizer built earlier presumably
            # still references the old ones - confirm this is intended.
            self._ema_w = nn.Parameter(
                self._ema_w * self._decay + (1 - self._decay) * dw)
            self._embedding.weight = nn.Parameter(
                self._ema_w / self._ema_cluster_size.unsqueeze(1))

        # Loss
        # The codebook loss is not needed with EMA updates; only the
        # commitment term remains: freeze the codeword and pull the encoder
        # output towards it.
        commitment_loss = F.mse_loss(codeword.detach(), inputs)
        loss = self._beta * commitment_loss

        # Copy gradients of codeword to inputs (straight-through)
        codeword = inputs + (codeword - inputs).detach()

        return loss, codeword


class VQVAE(nn.Module):
    """Autoencoder with a vector-quantised bottleneck.

    Args:
        in_features: width of the input and of the reconstruction.
        hidden_sizes: comma-separated encoder layer widths; the decoder uses
            the same widths in reverse order.
        latent_dim: passed to ``VectorQuantizer`` as the number of codebook
            entries; the codeword width is the last encoder layer's width.
        beta: commitment loss coefficient forwarded to the quantiser.
        device: forwarded to the quantiser.
    """

    def __init__(self, in_features: int,
                 hidden_sizes: str,
                 latent_dim: int,
                 beta: float,
                 device: torch.device) -> None:

        super().__init__()

        widths = [in_features] + [int(token) for token in hidden_sizes.split(',')]

        # Encoder: Linear + ReLU for every consecutive pair of widths.
        encoder_layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            encoder_layers += [nn.Linear(n_in, n_out), nn.ReLU()]
        self._encoder = nn.Sequential(*encoder_layers)

        # Quantiser operating on the encoder output.
        self._vq = VectorQuantizer(latent_dim, widths[-1], beta, device)

        # Decoder: mirrored widths, final Linear layer without activation.
        mirrored = widths[::-1]
        decoder_layers = []
        for n_in, n_out in zip(mirrored[:-1], mirrored[1:]):
            decoder_layers += [nn.Linear(n_in, n_out), nn.ReLU()]
        decoder_layers.pop()
        self._decoder = nn.Sequential(*decoder_layers)

    def forward(self, x: torch.Tensor):
        """Return ``(reconstruction, quantiser_loss)`` for the batch ``x``."""
        encoded = self._encoder(x)
        vq_loss, quantised = self._vq(encoded)
        return self._decoder(quantised), vq_loss

In [None]:
import torch
from torch.autograd import Variable
import math


class VqVaeService:
    """Runs one full pass of a VQ-VAE over a data split.

    The task name given at construction selects the reconstruction term: mean
    squared error when it equals ``COLOR_MNIST``, cross entropy against the
    input row otherwise.
    """

    def __init__(self, task: str) -> None:
        self._task = task

    def run_vqvae(self,
                  model, opt,
                  x_train, x_test,
                  split, batch_size,
                  device):
        """Run the model over one split, updating weights when ``split == 'train'``.

        Args:
            model: module returning ``(reconstruction, vq_loss)`` for a batch.
            opt: optimizer over ``model``'s parameters (only stepped for 'train').
            x_train: 2-D tensor used when ``split == 'train'``.
            x_test: 2-D tensor used for any other ``split`` value.
            split: ``'train'`` enables gradients and training mode.
            batch_size: rows per batch; the last batch may be smaller.
            device: device on which the per-sample loss tensor is allocated.

        Returns:
            ``(mean_batch_loss, loss_for_samples)``. Each batch loss is the
            summed reconstruction loss plus the quantiser loss;
            ``loss_for_samples`` holds one detached reconstruction value per row.

        Raises:
            ValueError: if ``batch_size`` is not in ``1..len(x)``.
        """
        is_train = split == 'train'
        x = x_train if is_train else x_test

        if batch_size <= 0 or batch_size > len(x):
            raise ValueError(
                "Batch size must be larger than 0 and smaller than sample size")

        N = x.size(0)
        B = batch_size  # previously hard-coded to 64, ignoring the argument
        nsteps = math.ceil(N / B)

        loss_for_samples = torch.full((N,), torch.nan, device=device)  # N
        loss_total = []

        if is_train:
            model.train()
        else:
            model.eval()

        # Scope the grad mode to this call so an evaluation pass does not
        # leave autograd disabled for whatever runs next in the notebook.
        with torch.set_grad_enabled(is_train):
            for step in range(nsteps):
                # ceil(N / B) steps already cover the final, shorter batch
                xb = x[step * B: step * B + B]
                self._run_batch(step, B, N, model, opt, xb, split,
                                loss_total, loss_for_samples)

        assert not torch.isnan(loss_for_samples).any()
        return sum(loss_total) / len(loss_total), loss_for_samples

    def _run_batch(self, step, B, N,
                   model, opt,
                   xb, split,
                   loss_total,
                   loss_for_samples):
        """Forward one batch, record its losses and (for 'train') update weights.

        ``step`` and ``B`` locate the batch inside ``loss_for_samples``; ``N`` is
        unused and kept only so the signature stays unchanged.
        """
        pred, vq_loss = model(xb)

        if self._task == COLOR_MNIST:
            # Gaussian
            loss_each = EvaluationUtils.mean_mse(pred, xb,
                                                 reduction='none')  # batch_size x D
            loss_sample = torch.mean(loss_each, dim=1)  # batch_size
            reconstruction_loss = EvaluationUtils.mean_mse(pred, xb,
                                                           reduction='sum')
        else:
            # Multinomial
            loss_sample = EvaluationUtils.mean_ce(pred, xb,
                                                  reduction='none')  # batch_size
            # NOTE(review): a row that sums to zero would produce NaN here and
            # trip the assertion in run_vqvae - confirm inputs never contain one.
            num_drugs_taken = torch.sum(xb, dim=1)  # batch_size
            loss_sample = loss_sample / num_drugs_taken  # batch_size
            reconstruction_loss = EvaluationUtils.mean_ce(pred, xb,
                                                          reduction='sum')

        # Optimised objective: summed reconstruction loss plus quantiser loss.
        loss_vae = reconstruction_loss + vq_loss  # scalar

        loss_total.append(loss_vae.item())

        # Store detached values so the bookkeeping tensor does not keep every
        # batch's autograd graph alive.
        start = step * B
        loss_for_samples[start: start + B] = loss_sample.detach()

        # backward/update
        if split == 'train':
            opt.zero_grad()
            loss_vae.backward()
            opt.step()


In [None]:
# Build the VQ-VAE. `vae_latent_dim` is used as the number of codebook
# entries and 0.25 as the commitment-loss coefficient.
vqvae = VQVAE(x.size(1), vae_hiddens, vae_latent_dim, 0.25, device)

vqvae.to(device)
vqvae_opt = torch.optim.Adam(vqvae.parameters(),
                             lr=vae_learning_rate,
                             weight_decay=vae_weight_decay)
# Multiply the learning rate by 0.1 every 45 epochs.
vqvae_scheduler = torch.optim.lr_scheduler.StepLR(
    vqvae_opt, step_size=45, gamma=0.1)

vqvae_service = VqVaeService(task)

vqvae_train_hist = []
for epoch in range(vae_epochs):
    vqvae_train_loss, _ = vqvae_service.run_vqvae(vqvae, vqvae_opt,
                                                  x_train, None,
                                                  'train', vae_batch_size,
                                                  device)

    vqvae_scheduler.step()
    vqvae_train_hist.append(vqvae_train_loss)

# For any split other than 'train' the service reads its data from the
# x_test argument, so the held-out tensor must be passed in that position
# (it was previously passed as x_train, leaving x_test as None).
vqvae_test_loss, _ = vqvae_service.run_vqvae(vqvae, vqvae_opt,
                                             None, x_test,
                                             'test', vae_batch_size,
                                             device)