In [1]:
import numpy as np
import torch
import torch.nn as nn
from torchdiffeq import odeint
import torchvision
import torchvision.transforms as transforms
from general import *
import time

In [2]:
# def func(t, y):
#     dy = -y
#     return dy

class Func:
    def __init__(self, input):
        self.cvA = nn.Conv2d(3, 3, 3, padding='same')
        self.cvB = nn.Conv2d(3, 3, 3, padding='same')
        self.Ib = 0.1
        self.input = input
    
    def __call__(self, t, y):
        return self.forward(t, y)

    def cnn(self, x):
        return 0.5 * (abs(x + 1) - abs(x - 1))

    def forward(self, t, y):
        dy = -y + self.Ib + self.cvB(self.input) + self.cvA(self.cnn(y))
        return dy


In [3]:

batch_size = 1
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                        shuffle=True, num_workers=2)
_, (input, label) = next(enumerate(train_loader))    
input = torch.tensor(input, dtype=torch.float)
print(input.shape)

Files already downloaded and verified
torch.Size([1, 3, 32, 32])


  input = torch.tensor(input, dtype=torch.float)


In [4]:
func = Func(input=input)
t = np.linspace(0, 10, 10)
print('dt: ', t[1]-t[0])

y0 = input
start1 = time.time()
y = odesolver(y0, t, dt=t[1]-t[0], func=func)
end1 = time.time()
print(f'\nEuler time: {(end1-start1):.4f}')
print(y[-1][0, 0, 1, :10])

start2 = time.time()
y = RK4solver(y0, t, dt=t[1]-t[0], func=func)
end2 = time.time()
print(f'\nRK4 time: {(end2-start2):.4f}')
print(y[-1][0, 0, 1, :10])

start3 = time.time()
y = RK4_altstep_solver(y0, t, dt=t[1]-t[0], func=func)
end3 = time.time()
print(f'\nRK4_altstep time: {(end3-start3):.4f}')
print(y[-1][0, 0, 1, :10])

start4 = time.time()
ode_result = odeint(func, y0=torch.tensor(y0), t=torch.tensor(t), method='rk4')
end4 = time.time()
print(f'\ntorchdiff_ode: {(end4-start4):.4f}')
print(ode_result[-1][0, 0, 1, :10])

dt:  1.1111111111111112

Euler time: 0.3276
tensor([-0.5963, -0.7258, -0.7618, -0.7838, -0.7638, -0.7838, -0.7964, -0.7832,
        -0.7869, -0.7729], grad_fn=<SliceBackward0>)

RK4 time: 0.0072
tensor([-0.6081, -0.7313, -0.7534, -0.7719, -0.7760, -0.7850, -0.7867, -0.7882,
        -0.7809, -0.7745], grad_fn=<SliceBackward0>)

RK4_altstep time: 0.0168
tensor([-0.6080, -0.7314, -0.7534, -0.7720, -0.7761, -0.7850, -0.7869, -0.7883,
        -0.7810, -0.7745], grad_fn=<SliceBackward0>)

torchdiff_ode: 0.0068
tensor([-0.6080, -0.7314, -0.7534, -0.7720, -0.7761, -0.7850, -0.7869, -0.7883,
        -0.7810, -0.7745], grad_fn=<SliceBackward0>)


  ode_result = odeint(func, y0=torch.tensor(y0), t=torch.tensor(t), method='rk4')


In [5]:
def odeint_warpper(t):
    return odeint(func, y0=torch.tensor(y0), t=torch.tensor(t, dtype=torch.float), method='rk4')

In [6]:
def run_multiprocessing():
    from multiprocessing import get_context
    with get_context("spawn").Pool(2) as pool:
        # ode_result = pool.map(odeint_warpper, [t[i*nfine: (i+1)*nfine] for i in range(t_fine_len)])
        ode_result = pool.map(odeint_warpper, [t[: 5], t[5: 2*5]])
    return ode_result

In [12]:
from pathos.multiprocessing import ProcessingPool as Pool
import multiprocessing


func = Func(input=input)
step_fine = 1
step_coarse = 5
t = np.arange(0, 10, step_fine)
nfine = step_coarse // step_fine
print(f'nfine: {nfine}')
t_fine_len = t.shape[0] // nfine
t_coarse = [t[i*nfine] for i in range(t_fine_len)]
t_coarse = t_coarse if t_coarse[-1] == t[-1] else t_coarse + [t[-1]]
t_coarse = np.array(t_coarse)

# t_coarse = np.array([i for i in np.arange(0, t[-1]+step_coarse-1, step_coarse)])

print(f'len.t:{t.shape[0]}, len.t_coarse:{t_coarse.shape[0]}, dt: {t[1]-t[0]}, dt_coarse: {t_coarse[1]-t_coarse[0]}')
print(t_coarse)
y0 = input

start_coarse = time.time()
y_coarse = odesolver(y0, t_coarse, dt=t_coarse[1]-t_coarse[0], func=func)
end_coarse = time.time()
print(f'\nEuler time (coarse): {(end_coarse-start_coarse):.4f}')
print(y[-1][0, 0, 1, :10])


start_fine = time.time()
# ode_result = run_multiprocessing()
with Pool(2) as pool:
    # ode_result = pool.map(odeint_warpper, [t[i*nfine: (i+1)*nfine] for i in range(t_fine_len)])
    ode_result = pool.map(odeint_warpper, [t[: nfine], t[nfine: 2*nfine]])

# for i in range(1):
#     ode_result.append(odeint(func, y0=torch.tensor(y0), t=torch.tensor(t[i*nfine: (i+1)*nfine], dtype=torch.float), method='rk4'))

end_fine = time.time()
print(f'\ntorchdiff_ode time (fine): {(end_fine-start_fine):.4f}')
print(f'total time: {(end_fine-start_coarse):.4f}')
print(ode_result[-1][-1][0, 0, 1, :10])

start1 = time.time()
y = odesolver(y0, t, dt=t[1]-t[0], func=func)
end1 = time.time()
print(f'\nEuler time: {(end1-start1):.4f}')
print(y[-1][0, 0, 1, :10])

start4 = time.time()
ode_result = odeint(func, y0=torch.tensor(y0), t=torch.tensor(t, dtype=torch.float), method='rk4')
end4 = time.time()
print(f'\ntorchdiff_ode: {(end4-start4):.4f}')
print(ode_result[-1][0, 0, 1, :10])

nfine: 5
len.t:10, len.t_coarse:3, dt: 1, dt_coarse: 5
[0 5 9]

Euler time (coarse): 0.0024
tensor([-0.7118, -0.6931, -0.7951, -0.7403, -0.7537, -0.7395, -0.7448, -0.7996,
        -0.7755, -0.7971], grad_fn=<SliceBackward0>)

torchdiff_ode time (fine): 0.0213
total time: 0.0250
tensor([-0.7659,  0.4542,  0.6074,  0.6161,  0.6146,  0.6054,  0.5958,  0.5832,
         0.5756,  0.5638], grad_fn=<SliceBackward0>)

Euler time: 0.0061
tensor([-1.9297, -3.7464, -3.7872, -3.7833, -3.7822, -3.7705, -3.7626, -3.7565,
        -3.7776, -3.8094], grad_fn=<SliceBackward0>)

torchdiff_ode: 0.0109
tensor([-1.9296, -3.7429, -3.7851, -3.7820, -3.7810, -3.7695, -3.7614, -3.7551,
        -3.7761, -3.8081], grad_fn=<SliceBackward0>)


  ode_result = odeint(func, y0=torch.tensor(y0), t=torch.tensor(t, dtype=torch.float), method='rk4')
