## Test torch functions

In [1]:
# check current working env
import os
print(os.environ['CONDA_DEFAULT_ENV'])

mysoden


In [3]:
import pandas as pd
import numpy as np
import torch

In [4]:
y = torch.tensor([1, 2, 3, 4, 5, 6])
y

tensor([1, 2, 3, 4, 5, 6])

In [16]:
# view: Returns a new tensor with the same data as the self tensor but of a different shape.
Lambda_t = y.index_select(-1, torch.tensor([0])).view(-1, 1) ## retrieve Lambda_t from y, returns as a 2-D tensor of 1 element
Lambda_t

tensor([[1]])

In [17]:
T = y.index_select(-1, torch.tensor([1])).view(-1, 1)  ## retrieve the final time step T from y, returns as a 2-D tensor of 1 element
T

tensor([[2]])

In [18]:
x = y.index_select(-1, torch.tensor(range(2, y.size(-1))))
x

tensor([3, 4, 5, 6])

In [19]:
torch.tensor(x, dtype=torch.long)

  torch.tensor(x, dtype=torch.long)


tensor([3, 4, 5, 6])

In [21]:
x.to(torch.long)

tensor([3, 4, 5, 6])

In [43]:
import torch.nn as nn
feature_size = 4

In [44]:
embed = nn.Embedding(1602, feature_size)
embed(x.to(torch.long))

tensor([[-0.4078, -0.7414, -1.5151, -0.7076],
        [-0.2540,  0.5423, -0.4350,  1.1568],
        [-0.7054, -0.3997,  0.3515, -0.0304],
        [ 2.1536,  0.7084, -0.3678, -1.0363]], grad_fn=<EmbeddingBackward0>)

In [45]:
t = torch.tensor(5)
t

tensor(5)

In [47]:
# concatenate Lambda_t, t*T, and x into a 1-D tensor
inp = torch.cat(
            [Lambda_t,
             t.repeat(T.size()) * T,  # s = t * T
             x.view(-1, feature_size)], dim=1)
inp

tensor([[ 1, 10,  3,  4,  5,  6]])

In [55]:
inp.size()

torch.Size([1, 6])

In [49]:
T.size()

torch.Size([1, 1])

In [50]:
t.repeat(T.size()) # Repeats the tensor t along the specified dimensions.

tensor([[5]])

In [51]:
t.repeat(T.size()) * T

tensor([[10]])

In [52]:
T

tensor([[2]])

In [53]:
x

tensor([3, 4, 5, 6])

In [54]:
## make a complete nn model with specified layers and sizes
def make_net(input_size, hidden_size, num_layers, output_size, dropout=0,
             batch_norm=False, act="relu", softplus=True):
    if act == "selu":
        ActFn = nn.SELU
    else:
        ActFn = nn.ReLU
    modules = [nn.Linear(input_size, hidden_size), ActFn()]   ## Applies a linear transformation to the incoming data
    if batch_norm:
        modules.append(nn.BatchNorm1d(hidden_size))
    if dropout > 0:
        modules.append(nn.Dropout(p=dropout))
    if num_layers > 1:
        for _ in range(num_layers - 1):
            modules.append(nn.Linear(hidden_size, hidden_size))
            modules.append(ActFn())
            if batch_norm:
                modules.append(nn.BatchNorm1d(hidden_size))
            if dropout > 0:
                modules.append(nn.Dropout(p=dropout))
    modules.append(nn.Linear(hidden_size, output_size))
    if softplus:  # ODE models
        modules.append(nn.Softplus())
    return nn.Sequential(*modules)

In [57]:
net = make_net(6, 3, num_layers = 2, output_size = 1, dropout=0,
             batch_norm=False, act="relu", softplus=True)

In [67]:
# solve dtype error
inp.dtype
inp = inp.to(torch.float32)

In [64]:
net(inp)

tensor([[0.4979]], grad_fn=<SoftplusBackward0>)

In [66]:
output = net(inp) * T
output

tensor([[0.9958]], grad_fn=<MulBackward0>)

In [68]:
x.view(-1, 1)

tensor([[3],
        [4],
        [5],
        [6]])

## Test foward function in NonCoxFuncModel

In [75]:
# define a sample input
inputs = {
  "t": torch.tensor(5),
  "init_cond": torch.tensor(1),
  "features": torch.tensor([[3, 4, 5, 6]])
}

In [77]:
# retrieve information from input
t = inputs["t"]
print("t: ", t)
init_cond = inputs["init_cond"]
print("init_cond: ", init_cond)
features = inputs["features"]
print("features: ", features)
init_cond = torch.cat([init_cond.view(-1, 1), t.view(-1, 1), features], dim=1)  ## rearrange; equiv to c(init_cond, t, features)
t = torch.tensor([0., 1.])
print("new t: ", t)

t:  tensor(5)
init_cond:  tensor(1)
features:  tensor([[3, 4, 5, 6]])
new t:  tensor([0., 1.])


In [73]:
init_cond.view(-1, 1)

tensor([[1]])

In [74]:
t.view(-1, 1)

tensor([[5]])

In [78]:
init_cond

tensor([[1, 5, 3, 4, 5, 6]])

In [79]:
class BaseSurvODEFunc(nn.Module):
    def __init__(self):
        super(BaseSurvODEFunc, self).__init__()
        self.nfe = 0
        self.batch_time_mode = False

    def set_batch_time_mode(self, mode=True):
        self.batch_time_mode = mode
        # `odeint` requires the output of `odefunc` to have the same size as
        # `init_cond` despite the how many steps we are going to evaluate. Set
        # `self.batch_time_mode` to `False` before calling `odeint`. However,
        # when we want to call the forward function of `odefunc` directly and
        # when we would like to evaluate multiple time steps at the same time,
        # set `self.batch_time_mode` to `True` and the output will have size
        # (len(t), size(y)).

    ## What is nfe??
    def reset_nfe(self):
        self.nfe = 0

    def forward(self, t, y):
        raise NotImplementedError("Not implemented.")

In [80]:
class ContextRecMLPODEFunc(BaseSurvODEFunc):
    def __init__(self, feature_size, hidden_size, num_layers, batch_norm=False,
                 use_embed=True):
        super(ContextRecMLPODEFunc, self).__init__()
        self.feature_size = feature_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.batch_norm = batch_norm
        self.use_embed = use_embed
        if use_embed:
            self.embed = nn.Embedding(1602, self.feature_size)  ## A simple lookup table that maps index value to a weighted matrix of certain dimension
        else:
            self.embed = None
        self.net = make_net(input_size=feature_size+2, hidden_size=hidden_size,
                            num_layers=num_layers, output_size=1,
                            batch_norm=batch_norm)
    
    ## ---where does the y come from?--- passed within odeint as init_cond?
    ## forward propagetion; one step forward
    ## outputs a number from nn
    def forward(self, t, y):
        """
        Arguments:
          t: When self.batch_time_mode is False, t is a scalar indicating the
            time step to be evaluated. When self.batch_time_mode is True, t is
            a 1-D tensor with a single element [1.0].
          y: When self.batch_time_mode is False, y is a 1-D tensor with length
            2 + k, where the first dim indicates Lambda_t, the second dim
            indicates the final time step T to be evaluated, and the remaining
            k dims indicates the features. When self.batch_time_mode is True, y
            is a 2-D tensor with batch_size * (2 + k).
        """
        self.nfe += 1
        device = next(self.parameters()).device
        Lambda_t = y.index_select(-1, torch.tensor([0]).to(device)).view(-1, 1) ## retrieve Lambda_t from y, returns as a 2-D tensor of 1 element
        T = y.index_select(-1, torch.tensor([1]).to(device)).view(-1, 1)  ## retrieve the final time step T from y, returns as a 2-D tensor of 1 element
        x = y.index_select(-1, torch.tensor(range(2, y.size(-1))).to(device))  ## retrieve features from y, returns as a 1-D tensor
        if self.use_embed:
            x = torch.mean(
                self.embed(torch.tensor(x, dtype=torch.long).to(device)),
                dim=1)
        # Rescaling trick  ## time rescaling
        # $\int_0^T f(s; x) ds = \int_0^1 T f(tT; x) dt$, where $t = s / T$
        inp = torch.cat(
            [Lambda_t,
             t.repeat(T.size()) * T,  # s = t * T; time step to be evaluated * final time step
             x.view(-1, self.feature_size)], dim=1)
        output = self.net(inp) * T  # f(tT; x) * T
        zeros = torch.zeros_like(
            y.index_select(-1, torch.tensor(range(1, y.size(-1))).to(device))
        )  ## Returns a tensor filled with the scalar value 0, with the same size as input
        output = torch.cat([output, zeros], dim=1)
        if self.batch_time_mode:
            return output
        else:
            return output.squeeze(0)

In [81]:
# define sample config
config = {
    "hidden_size": 5,
    "num_layers": 2,
    "batch_norm": False
}

In [84]:
# define odefunc
odefunc = ContextRecMLPODEFunc(
                feature_size = 4, hidden_size = config["hidden_size"], num_layers = config["num_layers"],
                batch_norm=config["batch_norm"], use_embed=False)

In [117]:
# initilize output
outputs = {}
# import ode solver
from torchdiffeq import odeint_adjoint as odeint

In [126]:
# change init_cond to floating point Tensor to fix datatype error
init_cond = init_cond.to(torch.float)
# solve ode for Lambda
outputs["Lambda"] = odeint(odefunc, init_cond, t, rtol=1e-4, atol=1e-8)[1:].squeeze()  # size: [length of t] x [batch size] x [dim of y0]  ## Solve ODE for cumulative hazard function
outputs["Lambda"] = outputs["Lambda"].view(1, outputs["Lambda"].size(-1)) ## add to fix dimension error
print(outputs)

{'Lambda': tensor([[4.2693, 5.0000, 3.0000, 4.0000, 5.0000, 6.0000]],
       grad_fn=<ViewBackward0>), 'lambda': tensor([3.2847, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
       grad_fn=<SqueezeBackward0>)}


In [127]:
# add outputs
odefunc.set_batch_time_mode(True)
outputs["lambda"] = odefunc(t[1:], outputs["Lambda"]).squeeze()
outputs["lambda"] = outputs["lambda"].view(1, outputs["lambda"].size(-1)) ## add to fix dimension error
outputs["Lambda"] = outputs["Lambda"][:, 0]
outputs["lambda"] = outputs["lambda"][:, 0] / inputs["t"]
print(outputs)

{'Lambda': tensor([4.2693], grad_fn=<SelectBackward0>), 'lambda': tensor([0.6569], grad_fn=<DivBackward0>)}


In [99]:
# debug
t[1:]

tensor([1.])

In [98]:
outputs["Lambda"].dim()

1

In [128]:
outputs["Lambda"][:, 0]

IndexError: too many indices for tensor of dimension 1

In [114]:
# debug dimsion error inside ContextRecMLPODEFunc
# convert y to 2d solves the problem
y = outputs["Lambda"]
y = y.view(1, 6)
print("y:", y)
t_1 = t[1:]
print("t_1:", t_1)

y: tensor([[4.2693, 5.0000, 3.0000, 4.0000, 5.0000, 6.0000]],
       grad_fn=<ViewBackward0>)
t_1: tensor([1.])


In [115]:
Lambda_t = y.index_select(-1, torch.tensor([0])).view(-1, 1) ## retrieve Lambda_t from y, returns as a 2-D tensor of 1 element
T = y.index_select(-1, torch.tensor([1])).view(-1, 1)  ## retrieve the final time step T from y, returns as a 2-D tensor of 1 element
x = y.index_select(-1, torch.tensor(range(2, y.size(-1))))
print("Lambda_t:", Lambda_t)
print("T: ", T)
print("x:", x)

Lambda_t: tensor([[4.2693]], grad_fn=<ViewBackward0>)
T:  tensor([[5.]], grad_fn=<ViewBackward0>)
x: tensor([[3., 4., 5., 6.]], grad_fn=<IndexSelectBackward0>)


In [116]:
feature_size = 4
inp = torch.cat([Lambda_t, t_1.repeat(T.size()) * T,  # s = t * T; time step to be evaluated * final time step
             x.view(-1, feature_size)], dim=1)
print("inp:", inp)
output = net(inp) * T  # f(tT; x) * T
print("output:", output)
zeros = torch.zeros_like(y.index_select(-1, torch.tensor(range(1, y.size(-1)))))  ## Returns a tensor filled with the scalar value 0, with the same size as input
print("zeros:", zeros)
output = torch.cat([output, zeros], dim=1)
print("output:", output)

inp: tensor([[4.2693, 5.0000, 3.0000, 4.0000, 5.0000, 6.0000]],
       grad_fn=<CatBackward0>)
output: tensor([[2.7066]], grad_fn=<MulBackward0>)
zeros: tensor([[0., 0., 0., 0., 0.]])
output: tensor([[2.7066, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
       grad_fn=<CatBackward0>)


In [109]:
print(t_1.repeat(T.size()) * T)

tensor([[5.]], grad_fn=<MulBackward0>)


In [107]:
T.size()

torch.Size([1, 1])

In [108]:
T

tensor([[5.]], grad_fn=<ViewBackward0>)

In [112]:
y.size(-1)

6