In [19]:
import torch
from torch.distributions.beta import Beta
import torch.optim as optim

def loss_fn(generated_sequence):
    """Placeholder objective for a sampled Markov sequence.

    ``generated_sequence`` is currently ignored. The result is a freshly
    created scalar tensor equal to 0.0 with ``requires_grad=True``, so calling
    ``.backward()`` on it succeeds. Because it is not connected to any other
    tensor, it produces no gradients for model parameters.

    Replace this with the actual loss function.
    """
    placeholder = torch.zeros((), requires_grad=True)
    return placeholder

def run_markov_sequence(transition_matrix, initial_state, length):
    """Sample a trajectory of ``length`` transitions from a Markov chain.

    Row ``s`` of ``transition_matrix`` is used as the (unnormalised) weights
    for the next state when the chain is in state ``s``. ``torch.multinomial``
    requires those weights to be non-negative with a positive sum.

    Returns a list of ``length + 1`` states. The first element is
    ``initial_state`` exactly as passed. Each subsequent element is the 0-d
    integer tensor picked by ``torch.multinomial``.
    """
    states = [initial_state]
    current = initial_state
    for _step in range(length):
        # One draw from the current row; [0] unwraps the 1-element result.
        current = torch.multinomial(transition_matrix[current], 1)[0]
        states.append(current)
    return states

# Parameters
learning_rate = 0.01
epochs = 1000
initial_state = torch.tensor(0)
num_samples = 20      # Beta draws averaged per epoch
sequence_length = 52  # transitions generated per epoch
eps = 1e-4            # margin that keeps the Beta concentrations strictly positive

# Learnable Beta parameters in mean/concentration form:
#   phi  -> mean of each entry, must stay in (0, 1)
#   lmda -> concentration (alpha + beta), must stay > 0
# so the actual Beta parameters are alpha = lmda * phi, beta = lmda * (1 - phi).
phi = torch.full((3, 3), 0.6, requires_grad=True)
lmda = torch.full((3, 3), 10.0, requires_grad=True)
optimizer = optim.Adam([phi, lmda], lr=learning_rate)

for epoch in range(epochs):
    optimizer.zero_grad()

    beta_dist = Beta(lmda * phi, lmda * (1 - phi))
    # rsample (not sample) keeps the draws attached to the autograd graph, so
    # gradients can reach phi and lmda. All draws are taken in one batched call.
    mean_sample = beta_dist.rsample((num_samples,)).mean(dim=0)

    # Row-normalize so each row is a probability distribution over next states.
    transition_matrix = mean_sample / mean_sample.sum(dim=1, keepdim=True)

    # NOTE: torch.multinomial returns integer indices and is not
    # differentiable. A loss computed only from generated_sequence therefore
    # cannot backpropagate to phi/lmda. A real loss must also use
    # transition_matrix, e.g. the log-probabilities of the sampled transitions.
    generated_sequence = run_markov_sequence(transition_matrix, initial_state, sequence_length)
    loss = loss_fn(generated_sequence)

    # Backpropagation
    loss.backward()
    optimizer.step()

    # Project the parameters back into their valid domain, so the next Beta
    # construction does not fail its argument validation.
    with torch.no_grad():
        phi.clamp_(eps, 1 - eps)
        lmda.clamp_(min=eps)

# After training
print(f'Final Alpha: {lmda.detach() * phi.detach()}')
print(f'Final Beta: {lmda.detach() * (1 - phi.detach())}')



Final Alpha: tensor([[6., 6., 6.],
        [6., 6., 6.],
        [6., 6., 6.]])
Final Beta: tensor([[4.0000, 4.0000, 4.0000],
        [4.0000, 4.0000, 4.0000],
        [4.0000, 4.0000, 4.0000]])
