In [10]:
# Integer codes for Markov-chain states (the transition matrix below is 3x3, so codes are 0..2).
DEPRESSED, MANIC = 1, 2

# Diagnosis labels; these strings are the keys of the count and target dictionaries.
NORMAL, MDD, BIPOLAR, MANIA = "normal", "MDD", "mixed bipolar", "pure mania"

In [11]:
# Target diagnosis distribution, expressed in percent (the four values sum to 100).
TARGET = dict([
    (NORMAL, 87.652),
    (MDD, 9.679),
    (BIPOLAR, 1.6629999999999998),
    (MANIA, 1.006),
])

In [12]:
import torch
from torch.distributions.beta import Beta
import torch.optim as optim
import numpy as np

def loss_fn(d, target=None, verbose=True):
    """Squared-error distance between a diagnosis histogram and a target distribution.

    The counts in ``d`` are rescaled to percentages (each count / total * 100; when
    the total is zero the counts are only multiplied by 100) and compared label by
    label with ``target``.

    Args:
        d: mapping of diagnosis label -> count. It is NOT modified.
        target: mapping of diagnosis label -> target percentage. Defaults to the
            module-level ``TARGET``.
        verbose: when True (the default), print the normalized histogram.

    Returns:
        A 0-dim float tensor holding the sum of squared percentage errors. It is
        created with ``requires_grad=True`` so ``.backward()`` can be called on it,
        but it is a fresh leaf tensor: it has no connection to any upstream
        parameters, so backpropagating it produces no gradients for them.

    Raises:
        KeyError: if ``d`` contains a label missing from ``target``.
    """
    if target is None:
        target = TARGET

    total = sum(d.values())
    # Build a new dict: the previous version normalized ``d`` in place and clobbered
    # the caller's counts.
    normalized = {k: (v / total if total != 0 else v) * 100 for k, v in d.items()}

    if verbose:
        print(f'Normalized dictionary: {normalized}')

    # Float accumulator: torch refuses requires_grad=True on an integer tensor, which
    # the old int ``0`` start produced for an empty ``d``.
    loss = 0.0
    for k, pct in normalized.items():
        loss += (pct - target[k]) ** 2
    return torch.tensor(loss, requires_grad=True)

def run_markov_sequence(transition_matrix, initial_state, length):
    """Sample a Markov-chain trajectory.

    Returns a list of ``length + 1`` states: ``initial_state`` followed by ``length``
    states, each drawn with ``torch.multinomial`` from the row of
    ``transition_matrix`` indexed by the previous state. Every sampled state is a
    0-dim integer tensor.
    """
    trajectory = [initial_state]
    current = initial_state
    for _step in range(length):
        current = torch.multinomial(transition_matrix[current], num_samples=1)[0]
        trajectory.append(current)
    return trajectory

def diagnosis(cont_depressed, manic_trans):
    """Map a sequence's two episode flags to a diagnosis label.

    depressive and manic -> BIPOLAR, depressive only -> MDD,
    manic only -> MANIA, neither -> NORMAL. Flags are tested for truthiness.
    """
    if not manic_trans:
        return MDD if cont_depressed else NORMAL
    return BIPOLAR if cont_depressed else MANIA

def update_counter(d, sequence):
    """Classify one state sequence and increment the matching diagnosis count in ``d``.

    The sequence is flagged as depressive when it contains at least two consecutive
    ``DEPRESSED`` states, and as manic when any state equals ``MANIC``. The two flags
    are turned into a label by ``diagnosis``. ``d`` is updated in place and must
    already contain every label ``diagnosis`` can return; nothing is returned.
    """
    states = np.array(sequence)
    has_depressive_run = False
    has_manic_state = False
    depressed_streak = 0

    for state in states:
        if state == DEPRESSED:
            depressed_streak += 1
            if depressed_streak >= 2:
                has_depressive_run = True
            continue
        # Any non-depressed state breaks the streak.
        depressed_streak = 0
        if state == MANIC:
            has_manic_state = True

    label = diagnosis(has_depressive_run, has_manic_state)
    d[label] += 1


# Parameters
learning_rate = 0.01
epochs = 1000
initial_state = torch.tensor(0)
n_beta_samples = 20    # Beta draws averaged into one transition matrix per epoch
n_sequences = 100      # simulated sequences per epoch
sequence_length = 52   # transitions per simulated sequence
baseline_decay = 0.9   # EMA factor for the REINFORCE baseline
sample_eps = 1e-6      # keeps Beta samples strictly inside (0, 1) so log_prob stays finite

# Seed torch so the Beta draws and Markov simulations are reproducible.
torch.manual_seed(0)

# Learnable Beta parameters in mean/concentration form:
# alpha = lmda * phi, beta = lmda * (1 - phi).
phi = torch.tensor([[0.6, 0.6, 0.6],
                    [0.6, 0.6, 0.6],
                    [0.6, 0.6, 0.6]], requires_grad=True)  # Beta mean
lmda = torch.tensor([[10.0, 10.0, 10.0],
                     [10.0, 10.0, 10.0],
                     [10.0, 10.0, 10.0]], requires_grad=True)  # Beta concentration
optimizer = optim.Adam([phi, lmda], lr=learning_rate)

# Why a score-function (REINFORCE) estimator: Beta.sample(), torch.multinomial and the
# diagnosis counting are all non-differentiable, and loss_fn wraps a Python float in a
# fresh leaf tensor. Calling loss.backward() directly therefore never populated
# phi.grad / lmda.grad, so Adam skipped both parameters and nothing was learned.
# Backpropagating (loss - baseline) * log p(samples) instead gives an unbiased estimate
# of d E[loss] / d(phi, lmda).
baseline = None
for epoch in range(epochs):
    optimizer.zero_grad()

    beta_dist = Beta(lmda * phi, lmda * (1 - phi))
    samples = beta_dist.sample((n_beta_samples,)).clamp(sample_eps, 1 - sample_eps)
    log_prob = beta_dist.log_prob(samples).sum()

    mean_sample = samples.mean(0)
    # Row-normalize so each row is a distribution over next states.
    transition_matrix = mean_sample / mean_sample.sum(1, keepdim=True)

    d = {NORMAL: 0, MDD: 0, BIPOLAR: 0, MANIA: 0}
    for _ in range(n_sequences):
        generated_sequence = run_markov_sequence(transition_matrix, initial_state, sequence_length)
        update_counter(d, generated_sequence)

    loss = loss_fn(d)
    loss_value = loss.item()  # only the value is usable; the tensor has no autograd history
    if baseline is None:
        baseline = loss_value
    advantage = loss_value - baseline  # baseline from previous epochs, independent of this draw
    baseline = baseline_decay * baseline + (1 - baseline_decay) * loss_value

    surrogate = advantage * log_prob
    surrogate.backward()
    optimizer.step()

    # Keep the Beta parameters valid: phi is a mean in (0, 1), lmda must stay positive.
    with torch.no_grad():
        phi.clamp_(0.01, 0.99)
        lmda.clamp_(min=1.0)

# After training
print(f'Final Alpha: {lmda.detach() * phi.detach()}')
print(f'Final Beta: {lmda.detach() * (1 - phi.detach())}')



Normalized dictionary: {'normal': 0.0, 'MDD': 0.0, 'mixed bipolar': 100.0, 'pure mania': 0.0}
Normalized dictionary: {'normal': 0.0, 'MDD': 0.0, 'mixed bipolar': 97.0, 'pure mania': 3.0}
Normalized dictionary: {'normal': 0.0, 'MDD': 0.0, 'mixed bipolar': 100.0, 'pure mania': 0.0}
Normalized dictionary: {'normal': 0.0, 'MDD': 0.0, 'mixed bipolar': 100.0, 'pure mania': 0.0}
Normalized dictionary: {'normal': 0.0, 'MDD': 0.0, 'mixed bipolar': 100.0, 'pure mania': 0.0}
Normalized dictionary: {'normal': 0.0, 'MDD': 0.0, 'mixed bipolar': 100.0, 'pure mania': 0.0}
Normalized dictionary: {'normal': 0.0, 'MDD': 0.0, 'mixed bipolar': 99.0, 'pure mania': 1.0}
Normalized dictionary: {'normal': 0.0, 'MDD': 0.0, 'mixed bipolar': 97.0, 'pure mania': 3.0}
Normalized dictionary: {'normal': 0.0, 'MDD': 0.0, 'mixed bipolar': 100.0, 'pure mania': 0.0}
Normalized dictionary: {'normal': 0.0, 'MDD': 0.0, 'mixed bipolar': 97.0, 'pure mania': 3.0}
Normalized dictionary: {'normal': 0.0, 'MDD': 0.0, 'mixed bipola

KeyboardInterrupt: 