In [23]:
import os
import sys
import torch
from torch import nn
from torch import functional as F

import numpy as np
import pandas as pd

import sympy
from dataclasses import dataclass
from torchvision import transforms
from torch.optim.adam import Adam
from torch.distributions import Bernoulli
from torch.distributions import Categorical
from torch.nn.functional import softplus,softmax

In [None]:
nn.Softmax()

In [3]:
from torch.utils.tensorboard import SummaryWriter

In [4]:
from conditional_rate_matching.configs.config_files import ExperimentFiles
from conditional_rate_matching.data.states_dataloaders import sample_categorical_from_dirichlet
from graph_bridges.models.temporal_networks.embedding_utils import transformer_timestep_embedding

In [5]:
@dataclass
class Config:

    # data
    number_of_spins :int = 3
    number_of_states :int = 4
    sample_size :int = 200

    dirichlet_alpha_0 :float = 0.1
    dirichlet_alpha_1 :float = 100.

    bernoulli_probability_0 :float = 0.2
    bernoulli_probability_0 :float = 0.8

    # process
    gamma :float = .9

    # model

    # temporal network
    time_embed_dim :int = 9
    hidden_dim :int = 50

    # rate
    loss:str = "classifier" # classifier,naive
    
    # training
    number_of_epochs = 1
    learning_rate = 0.01
    batch_size :int = 5
    device = "cuda:0"

    #pipeline
    number_of_steps:int = 20
    num_intermediates:int = None

    def __post_init__(self):
        self.num_intermediates = int(.5*self.number_of_steps)
        
config = Config()

In [6]:
from torch.utils.data import DataLoader,TensorDataset
from torch.distributions import Dirichlet
import torch

# Parameters
dataset_0 = sample_categorical_from_dirichlet(probs=None,
                                              alpha=config.dirichlet_alpha_0,
                                              sample_size=config.sample_size,
                                              dimension=config.number_of_spins,
                                              number_of_states=config.number_of_states)
tensordataset_0 = TensorDataset(dataset_0)
dataloader_0 = DataLoader(tensordataset_0,batch_size=config.batch_size)

dataset_1 = sample_categorical_from_dirichlet(probs=None,
                                              alpha=config.dirichlet_alpha_1,
                                              sample_size=183,
                                              dimension=config.number_of_spins,
                                              number_of_states=config.number_of_states)
tensordataset_1 = TensorDataset(dataset_1)
dataloader_1 = DataLoader(tensordataset_1,batch_size=config.batch_size)

In [7]:
size_dataset_0 = len(tensordataset_0)
size_dataset_1 = len(tensordataset_1)

In [8]:
def beta_integral(gamma, t1, t0):
    """
    Dummy integral for constant rate
    """
    interval = t1 - t0
    integral = gamma * interval
    return integral

def conditional_probability(config, x, x0, t, t0):
    """

    \begin{equation}
    P(x(t) = i|x(t_0)) = \frac{1}{s} + w_{t,t_0}\left(-\frac{1}{s} + \delta_{i,x(t_0)}\right)
    \end{equation}

    \begin{equation}
    w_{t,t_0} = e^{-S \int_{t_0}^{t} \beta(r)dr}
    \end{equation}

    """
    right_shape = lambda x: x if len(x.shape) == 3 else x[:, :, None]
    right_time_size = lambda t: t if isinstance(t, torch.Tensor) else torch.full((x.size(0),), t)

    t = right_time_size(t).to(x0.device)
    t0 = right_time_size(t0).to(x0.device)

    S = config.number_of_states
    integral_t0 = beta_integral(config.gamma, t, t0)

    w_t0 = torch.exp(-S * integral_t0)

    x = right_shape(x)
    x0 = right_shape(x0)

    delta_x = (x == x0).float()
    probability = 1. / S + w_t0[:, None, None] * ((-1. / S) + delta_x)

    return probability

\begin{equation}
P(x(t) = i|x(t_0) = j) = \frac{1}{s} + w_{t,t_0}\left(-\frac{1}{s} + \delta_{i,x(t_0)=j}\right)
\end{equation}

\begin{equation}
w_{t,t_0} = e^{-S \int_{t_0}^{t} \beta(r)dr}
\end{equation}

In [19]:
def conditional_transition_probability(config, x, x1, x0, t):
    """
    \begin{equation}
    P(x_t=x|x_0,x_1) = \frac{p(x_1|x_t=x) p(x_t = x|x_0)}{p(x_1|x_0)}
    \end{equation}
    """

    P_x_to_x1 = conditional_probability(config, x1, x, t=1., t0=t)
    P_x0_to_x = conditional_probability(config, x, x0, t=t, t0=0.)
    P_x0_to_x1 = conditional_probability(config, x1, x0, t=1., t0=0.)

    conditional_transition_probability = (P_x_to_x1 * P_x0_to_x) / P_x0_to_x1
    return conditional_transition_probability

def constant_rate(config, x, t):
    right_time_size = lambda t: t if isinstance(t, torch.Tensor) else torch.full((x.size(0),), t)
    t = right_time_size(t).to(x.device)

    batch_size = x.size(0)
    dimension = x.size(1)

    assert batch_size == t.size(0)

    rate_ = torch.full((batch_size, dimension, config.number_of_states),
                       config.gamma)
    return rate_

def conditional_transition_rate(config, x, x1, t):
    """
    \begin{equation}
    f_t(\*x'|\*x,\*x_1) = \frac{p(\*x_1|x_t=\*x')}{p(\*x_1|x_t=\*x)}f_t(\*x'|\*x)
    \end{equation}
    """
    where_to_x = torch.arange(0, config.number_of_states)
    where_to_x = where_to_x[None, None, :].repeat((x.size(0), config.number_of_spins, 1)).float()
    where_to_x = where_to_x.to(x.device)

    P_xp_to_x1 = conditional_probability(config, x1, where_to_x, t=1., t0=t)
    P_x_to_x1 = conditional_probability(config, x1, x, t=1., t0=t)

    forward_rate = constant_rate(config, x, t).to(x.device)
    rate_transition = (P_xp_to_x1 / P_x_to_x1) * forward_rate

    return rate_transition

In [20]:
#torch.softmax(conditional_transition_probability,dim=-1)

\begin{equation}
    P(x_t=x|x_0,x_1) = \frac{p(x_1|x_t=x) p(x_t = x|x_0)}{p(x_1|x_0)}
\end{equation}

$\newcommand{\*}[1]{\bar{\mathbf{#1}}}$

\begin{equation}
f_t(\*x'|\*x,\*x_1) = \frac{p(\*x_1|x_t=\*x')}{p(\*x_1|x_t=\*x)}f_t(\*x'|\*x)
\end{equation}


In [21]:
def uniform_pair_x0_x1(batch_1,batch_0):
    x_0 = batch_0[0]
    x_1 = batch_1[0]
    
    batch_size_0 = x_0.size(0)
    batch_size_1 = x_1.size(0)
    
    batch_size = min(batch_size_0,batch_size_1)
    
    x_0 = x_0[:batch_size,:]
    x_1 = x_1[:batch_size,:]
    
    return x_1,x_0

In [26]:
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

x_to_go = torch.arange(0,config.number_of_states)
x_to_go = x_to_go[None,None,:].repeat((config.batch_size,config.number_of_spins,1)).float()
x_to_go = x_to_go.to(device)

if config.loss == "naive":
    model = ConditionalBackwardRate(config,device)
    loss_fn = nn.MSELoss()
elif config.loss == "classifier":
    model = ClassificationBackwardRate(config, device).to(device)
    loss_fn = nn.CrossEntropyLoss()

# initialize
experiment_files = ExperimentFiles(experiment_name="crm",
                                   experiment_type="dirichlet",
                                   experiment_indentifier="test2",
                                   delete=True)
experiment_files.create_directories()
writer = SummaryWriter(experiment_files.tensorboard_path)
optimizer = Adam(model.parameters(),lr=config.learning_rate)

number_of_training_steps = 0
for epoch in range(config.number_of_epochs):
    for batch_1, batch_0 in zip(dataloader_1, dataloader_0):
        
        #data pair and time sample
        x_1,x_0 = uniform_pair_x0_x1(batch_1,batch_0)
        x_0 = x_0.float().to(device)
        x_1 = x_1.float().to(device)
        
        batch_size = x_0.size(0)
        time = torch.randn(batch_size).to(device)
        
        #sample x from z
        transition_logits = conditional_transition_probability(config,x_to_go,x_1,x_0,time)
        transition_probs = torch.softmax(transition_logits,dim=-1)
        sampled_x = Categorical(transition_probs).sample().to(device)
        
        # conditional rate
        if config.loss == "naive":
            conditional_rate = conditional_transition_rate(config,sampled_x,x_1,time)
            nn.so
            model_rate = model(sampled_x,time)
            loss = loss_fn(model_rate,conditional_rate)
        elif config.loss == "classifier":
            model_classification = model(x_1,time)                        
            loss = loss_fn(model_classification.view(-1, config.number_of_states),
                           sampled_x.view(-1))
        
        writer.add_scalar('training loss', loss.item(), number_of_training_steps)

        # optimization
        optimizer.zero_grad()
        loss = loss.mean()
        loss.backward()
        optimizer.step()
        number_of_training_steps += 1
        
        if number_of_training_steps % 100 == 0:
            print(f"loss {round(loss.item(),2)}")

RuntimeError: The size of tensor a (3) must match the size of tensor b (5) at non-singleton dimension 0

In [29]:
classification_model = ClassificationBackwardRate(config,device)

In [30]:
classification_model(x_1, time)

tensor([[[-1.7164,  0.1773, -0.3872,  0.8108],
         [-0.6507, -0.4333, -1.2023, -0.7058],
         [ 0.4041,  0.1171, -0.3459, -0.8481]],

        [[-0.2368, -0.0823,  0.0034,  0.7070],
         [-0.6131,  0.2125, -0.1449, -0.3256],
         [ 0.6285,  0.3368, -1.2210, -0.6104]],

        [[-0.5464, -0.1911, -1.0401,  0.7035],
         [-0.8621, -0.1124, -0.7936, -1.1791],
         [ 0.8389, -0.4177, -0.7438, -0.8912]],

        [[-1.9241, -0.3258, -1.1273,  0.9582],
         [-0.8725, -0.0349, -0.7608, -1.3455],
         [ 0.5610,  0.3760, -0.7450, -1.1535]],

        [[-1.2245,  0.3802, -0.1008,  0.5919],
         [-0.5836, -0.4689, -1.1827, -0.2795],
         [ 0.1145,  0.1531, -0.1293, -0.4495]]], device='cuda:0',
       grad_fn=<ReshapeAliasBackward0>)

In [None]:
sampled_x.shape

In [14]:
class ClassificationBackwardRate(nn.Module):
    
    def __init__(self,config,device):
        super().__init__()
        
        self.S = config.number_of_states
        self.D = config.number_of_spins
        self.time_embed_dim = config.time_embed_dim
        self.hidden_layer = config.hidden_dim
        self.dimension = self.D
        self.num_states = self.S
        
        self.expected_data_shape = [config.number_of_spins]
        self.define_deep_models()
        self.init_weights()
        
    def define_deep_models(self):
        self.f1 = nn.Linear(self.dimension, self.hidden_layer)
        self.f2 = nn.Linear(self.hidden_layer + self.time_embed_dim, self.dimension * self.num_states)
        
        #self.f1 = nn.Linear(self.dimension, self.hidden_layer)
        #self.f2 = nn.Linear(self.hidden_layer + self.time_embed_dim, self.dimension * self.num_states)
        
    def to_go(self,x,t):
        x_to_go = torch.arange(0,self.S)
        x_to_go = x_to_go[None,None,:].repeat((batch_size,self.D,1)).float()
        x_to_go = x_to_go.to(device)
        return x_to_go
        
    def classify(self,x,times):
        batch_size = x.shape[0]
        time_embbedings = transformer_timestep_embedding(times,
                                                         embedding_dim=self.time_embed_dim)

        step_one = self.f1(x)
        step_two = torch.concat([step_one, time_embbedings], dim=1)
        rate_logits = self.f2(step_two)
        rate_logits = rate_logits.reshape(batch_size,self.dimension,self.num_states)

        return rate_logits

    def forward(self,x,t):
        right_shape = lambda x: x if len(x.shape) == 3 else x[:, :, None]
        right_time_size = lambda t: t if isinstance(t, torch.Tensor) else torch.full((x.size(0),), t).to(x.device)    
    
        batch_size = x.size(0)
        
        w_1t = beta_integral(config.gamma, right_time_size(1.), right_time_size(t))
        A = 1.
        B = (w_1t*self.S)/(1.-w_1t)
        C = w_1t 
        
        x_to_go = self.to_go(x,t)
        x_to_go = x_to_go.view((batch_size*self.S,self.D))
        rate_logits = self.classify(x,time)  
        return rate_logits
    
    def init_weights(self):
        nn.init.xavier_uniform_(self.f1.weight)
        nn.init.xavier_uniform_(self.f2.weight)

In [15]:
class ConditionalBackwardRate(nn.Module):
    """
    """
    def __init__(self,config,device):
        super().__init__()
        self.expected_data_shape = [config.number_of_spins]
        
        self.temporal_network = TemporalMLP(dimensions=config.number_of_spins,
                                       number_of_states=config.number_of_states,
                                       time_embed_dim=config.time_embed_dim,
                                       hidden_dim=config.hidden_dim,
                                       device=device).to(device)
        
        #self.logits_to_rates = nn.Linear(self.temporal_network_output_size,)
        
    def forward(self,x,time):
        batch_size = x.size(0)
        #================================
        #
        expected_data_shape_ = torch.Size([batch_size] + self.expected_data_shape)
        
        temporal_network_logits = self.temporal_network(x,time)
        rates_ = softplus(temporal_network_logits)
        
        return rates_

In [16]:
class TemporalMLP(nn.Module):

    def __init__(self,dimensions,number_of_states,time_embed_dim,hidden_dim,device):
        super().__init__()

        self.time_embed_dim = time_embed_dim
        self.hidden_layer = hidden_dim
        self.num_states = number_of_states
        self.dimension = dimensions    
        self.expected_output_shape = [self.dimension,self.num_states]
    
        self.define_deep_models()
        self.init_weights()
        
        self.device = device
        self.to(self.device)
        

    def define_deep_models(self):
        # layers
        self.f1 = nn.Linear(self.dimension, self.hidden_layer)
        self.f2 = nn.Linear(self.hidden_layer + self.time_embed_dim, self.dimension * self.num_states)

    def forward(self,x,times):
        batch_size = x.shape[0]
        time_embbedings = transformer_timestep_embedding(times,
                                                         embedding_dim=self.time_embed_dim)

        step_one = self.f1(x)
        step_two = torch.concat([step_one, time_embbedings], dim=1)
        rate_logits = self.f2(step_two)
        rate_logits = rate_logits.reshape(batch_size,self.dimension,self.num_states)

        return rate_logits

    def init_weights(self):
        nn.init.xavier_uniform_(self.f1.weight)
        nn.init.xavier_uniform_(self.f2.weight)