In [110]:
import torch
from torch import nn
from torch import functional as F

import numpy as np
import pandas as pd

from dataclasses import dataclass
from torchvision import transforms
from torch.distributions import Bernoulli
import sympy

In [15]:
@dataclass
class Config:
    """
    Hyper-parameters shared by the notebook cells.

    Only the annotated attributes are dataclass fields (and therefore
    constructor arguments, in this order): ``number_of_spins``,
    ``number_of_states``, ``sample_size``, ``gamma``, ``batch_size``.
    The unannotated attributes are plain class-level constants; they are
    deliberately left unannotated so the positional constructor signature
    does not change.
    """
    # data
    number_of_spins: int = 3
    number_of_states: int = 4
    sample_size: int = 100

    # Previously both lines assigned `bernoulli_probability_0`, so the 0.2 value
    # was silently overwritten by 0.8. Presumably these are the Bernoulli
    # parameters of the two end-point distributions -- TODO confirm.
    bernoulli_probability_0 = 0.2
    bernoulli_probability_1 = 0.8

    # process
    gamma: float = .2

    # training
    number_of_epochs = 10
    learning_rate = 0.01
    batch_size: int = 10

config = Config()

In [123]:
# Random placeholder tensors used to exercise the functions defined in this notebook.
# Every state tensor holds integers drawn from [0, number_of_states), stored as
# float32, with shape (batch_size, number_of_spins).
batchdata = torch.randint(low=0, high=config.number_of_states,
                          size=(config.batch_size, config.number_of_spins)).to(torch.float32)

x0 = torch.randint(low=0, high=config.number_of_states,
                   size=(config.batch_size, config.number_of_spins)).to(torch.float32)

xt = torch.randint(low=0, high=config.number_of_states,
                   size=(config.batch_size, config.number_of_spins)).to(torch.float32)

x1 = torch.randint(low=0, high=config.number_of_states,
                   size=(config.batch_size, config.number_of_spins)).to(torch.float32)

# Grid of every state value 0..number_of_states-1, tiled to
# (batch_size, number_of_spins, number_of_states).
x = torch.arange(config.number_of_states, dtype=torch.float32)
x = x.view(1, 1, -1).repeat(config.batch_size, config.number_of_spins, 1)

# One time value per batch element, uniform in [0, 1).
t = torch.rand((config.batch_size,))

In [91]:
def right_shape(x):
    """Return ``x`` as is when it already has three axes; otherwise append a trailing singleton axis."""
    if len(x.shape) == 3:
        return x
    return x[:, :, None]

In [92]:
 # Scratch check: builds a 1-D tensor of length config.batch_size filled with 1.0.
 torch.full((config.batch_size,),1.)

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [98]:
# Sanity check of the decay weight w_{1,0} = exp(-S * integral_{0}^{1} beta(r) dr).
# `S` and `right_time_size` are defined here so the cell no longer relies on names
# left behind in the kernel by earlier, out-of-order execution.
# NOTE: `beta_integral` must already have been defined when this cell runs.
S = config.number_of_states

def right_time_size(t):
    """Return tensor times unchanged; expand a scalar to shape (config.batch_size,)."""
    if isinstance(t, torch.Tensor):
        return t
    return torch.full((config.batch_size,), t)

t1 = right_time_size(1.)
t0 = right_time_size(0.)
integral_t0 = beta_integral(config.gamma,t1,t0)
w_t0  = torch.exp(-S*integral_t0)
w_t0

tensor([0.4493, 0.4493, 0.4493, 0.4493, 0.4493, 0.4493, 0.4493, 0.4493, 0.4493,
        0.4493])

In [128]:
def conditional_probability(config,x,x0,t,t0):
    r"""
    Probability of state ``x`` at time ``t`` given state ``x0`` at time ``t0``.

    \begin{equation}
    P(x(t) = i|x(t_0)) = \frac{1}{s} + w_{t,t_0}\left(-\frac{1}{s} + \delta_{i,x(t_0)}\right)
    \end{equation}

    \begin{equation}
    w_{t,t_0} = e^{-S \int_{t_0}^{t} \beta(r)dr}
    \end{equation}

    Parameters
    ----------
    config : object exposing ``number_of_states`` (S) and ``gamma``
        (the rate passed to ``beta_integral``).
    x : torch.Tensor
        State(s) whose probability is evaluated. A tensor that does not
        already have three axes gets a trailing singleton axis.
    x0 : torch.Tensor
        Conditioning state(s) at ``t0``; reshaped like ``x`` and compared to it
        with broadcasting.
    t, t0 : float or torch.Tensor
        End and start time. A scalar is expanded to one value per row of ``x``;
        a tensor is used as given (it is indexed as a 1-D tensor below).

    Returns
    -------
    torch.Tensor
        ``1/S + w * (delta(x, x0) - 1/S)`` with the broadcast shape of the
        reshaped ``x`` and ``x0``.
    """
    def right_shape(state):
        return state if len(state.shape) == 3 else state[:,:,None]

    # Size scalar times from the data itself rather than config.batch_size, so the
    # function also works for batches of a different size.
    batch_size = x.shape[0]

    def right_time_size(time):
        if isinstance(time, torch.Tensor):
            return time
        return torch.full((batch_size,), float(time), device=x.device)

    # Honour the caller's times. The previous code replaced `t0` with zeros (and
    # built an unused `t1`), so every call used the interval [0, t].
    t = right_time_size(t)
    t0 = right_time_size(t0)

    S = config.number_of_states

    integral_t0 = beta_integral(config.gamma,t,t0)
    w_t0 = torch.exp(-S*integral_t0)

    x = right_shape(x)
    x0 = right_shape(x0)

    # Kronecker delta between evaluated and conditioning states.
    delta_x = (x == x0).float()
    probability = 1./S + w_t0[:,None,None]*((-1./S) + delta_x)

    return probability

\begin{equation}
    P(x) = e^{-2\gamma\tau}(P_0 - \frac{1}{2}) + \frac{1}{2}
\end{equation}

\begin{equation}
P(x(t) = i|x(t_0)) = \frac{1}{s} + w_{t,t_0}\left(-\frac{1}{s} + \delta_{i,x(t_0)}\right)
\end{equation}

\begin{equation}
w_{t,t_0} = e^{-S \int_{t_0}^{t} \beta(r)dr}
\end{equation}

In [6]:
def beta_integral(gamma,t1,t0):
    """
    Integral of a constant rate ``gamma`` between ``t0`` and ``t1``.

    Stand-in for a time-dependent rate: simply ``gamma * (t1 - t0)``.
    """
    return gamma * (t1 - t0)

In [154]:
def conditional_transition_probability(config,x,x1,x0,t):
    r"""
    Probability of state ``x`` at time ``t`` given both end points ``x0`` (time 0)
    and ``x1`` (time 1), obtained from three calls to ``conditional_probability``.

    \begin{equation}
    P(x_t=x|x_0,x_1) = \frac{p(x_1|x_t=x) p(x_t = x|x_0)}{p(x_1|x_0)}
    \end{equation}

    Parameters
    ----------
    config : forwarded unchanged to ``conditional_probability``.
    x : state(s) at the intermediate time ``t``.
    x1 : state(s) at time 1.
    x0 : state(s) at time 0.
    t : intermediate time, forwarded as ``t0`` of the first factor and ``t`` of
        the second.

    Returns
    -------
    The ratio ``p(x1|x) * p(x|x0) / p(x1|x0)``; its shape follows from
    broadcasting the three factors.
    """
    # p(x_1 | x_t = x): from time t to time 1
    P_x_to_x1 = conditional_probability(config,x1,x,t=1.,t0=t)
    # p(x_t = x | x_0): from time 0 to time t
    P_x0_to_x = conditional_probability(config,x,x0,t=t,t0=0.)
    # p(x_1 | x_0): normaliser, from time 0 to time 1
    P_x0_to_x1 = conditional_probability(config,x1,x0,t=1.,t0=0.)

    # Return directly rather than via a local that shadows the function name.
    return (P_x_to_x1*P_x0_to_x)/P_x0_to_x1

In [121]:
#torch.softmax(conditional_transition_probability,dim=-1)

\begin{equation}
    P(x_t=x|x_0,x_1) = \frac{p(x_1|x_t=x) p(x_t = x|x_0)}{p(x_1|x_0)}
\end{equation}

$\newcommand{\*}[1]{\bar{\mathbf{#1}}}$

\begin{equation}
f_t(\*x'|\*x,\*x_1) = \frac{p(\*x_1|x_t=\*x')}{p(\*x_1|x_t=\*x)}f_t(\*x'|\*x)
\end{equation}


In [153]:
def constant_rate(config,x,t):
    """
    Time- and state-independent rate: a tensor filled with ``config.gamma``.

    Parameters
    ----------
    config : object exposing ``gamma`` and ``number_of_states``.
    x : torch.Tensor
        Only its first two sizes (batch, dimension) are read.
    t : float or torch.Tensor
        Not used in the value of the rate; only checked to hold one entry per
        row of ``x``. A scalar is expanded to ``x.size(0)`` entries.

    Returns
    -------
    torch.Tensor
        Shape ``(x.size(0), x.size(1), config.number_of_states)``, every entry
        equal to ``config.gamma``, created on ``x.device``.
    """
    batch_size = x.size(0)
    dimension = x.size(1)

    # Expand scalar times from the data's batch size; using config.batch_size here
    # made the check below fail whenever x had a different batch size.
    if not isinstance(t, torch.Tensor):
        t = torch.full((batch_size,), float(t), device=x.device)

    assert batch_size == t.size(0), "t must hold one time per batch element of x"

    rate_ = torch.full((batch_size,dimension,config.number_of_states),
                       config.gamma,
                       device=x.device)
    return rate_

def conditional_transition_rate(config,x,x1,t):
    r"""
    Rate of jumping from the current state ``x`` to every candidate state,
    conditioned on the end point ``x1``.

    \begin{equation}
    f_t(\*x'|\*x,\*x_1) = \frac{p(\*x_1|x_t=\*x')}{p(\*x_1|x_t=\*x)}f_t(\*x'|\*x)
    \end{equation}

    (``\*`` is the bold-bar macro declared with ``\newcommand`` in the notebook's
    markdown.)

    Parameters
    ----------
    config : object exposing ``number_of_states``; also forwarded to
        ``conditional_probability`` and ``constant_rate``.
    x : torch.Tensor
        Current state(s) at time ``t``; its first two sizes give batch and
        dimension.
    x1 : torch.Tensor
        State(s) at time 1.
    t : float or torch.Tensor
        Current time.

    Returns
    -------
    torch.Tensor
        ``p(x1|x') / p(x1|x) * constant_rate``, with the candidate states ``x'``
        along the last axis.
    """
    batch_size, dimension = x.size(0), x.size(1)

    # All candidate target states 0..S-1 for every (batch, dimension) entry.
    # Sized from x rather than config.batch_size / config.number_of_spins so the
    # function works for any batch; expand() avoids materialising the copies.
    where_to_x = torch.arange(config.number_of_states, dtype=torch.float32, device=x.device)
    where_to_x = where_to_x[None,None,:].expand(batch_size, dimension, config.number_of_states)

    # p(x_1 | x_t = x') for every candidate x', and p(x_1 | x_t = x) for the current state.
    P_xp_to_x1 = conditional_probability(config,x1,where_to_x,t=1.,t0=t)
    P_x_to_x1 =  conditional_probability(config,x1,x,t=1.,t0=t)

    forward_rate = constant_rate(config,x,t)

    rate_transition = (P_xp_to_x1/P_x_to_x1)*forward_rate

    return rate_transition

In [151]:
# Pass the current state `xt` (batch, spins). The candidate target states are
# enumerated inside conditional_transition_rate; passing the state grid `x` as the
# current state made the probability ratio identically 1.
rate = conditional_transition_rate(config,xt,x1,t)