In [1]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical, Normal, kl_divergence

import data

# Model definitions

In [2]:
class Model(nn.Module):
    """Bounded-rational agent with a learned, factorised prior over weight triples.

    The prior over 3-dimensional weight combinations is the outer product of
    three independent categoricals, one row of ``params`` per weight dimension.
    On each trial the agent forms the posterior ``prior * exp(beta * utility)``,
    samples an action from it, and ``learn`` then takes one SGD step that makes
    the sampled weights more probable under the prior.

    Args:
        num_actions: number of admissible values per weight dimension (size of
            each prior factor; the joint prior has ``num_actions ** 3`` entries).
        beta: inverse temperature multiplying the utilities.
        alpha: SGD learning rate of the prior update (``0`` disables learning).
        combinations: optional ``(num_actions ** 3, 3)`` integer tensor mapping a
            flat action index to its weight triple. When ``None`` the
            notebook-level ``combinations`` grid is looked up in ``learn``.
        offset: value added to a weight to turn it into a category index.
            Defaults to ``(num_actions - 1) // 2``, which assumes a grid centred
            on zero such as ``arange(-10, 11)`` (offset 10).
    """

    def __init__(self, num_actions, beta=10, alpha=0.0, combinations=None, offset=None):
        super(Model, self).__init__()
        self.beta = beta
        # One row of unnormalised log-probabilities per weight dimension.
        self.params = nn.Parameter(torch.zeros(3, num_actions))
        # float() so a 0-d tensor taken from a linspace sweep is accepted as lr.
        self.optimizer = optim.SGD(self.parameters(), lr=float(alpha))
        self.combinations = combinations
        self.offset = (num_actions - 1) // 2 if offset is None else offset

    def forward(self, utilities):
        """Sample an action from the utility-weighted posterior.

        Args:
            utilities: tensor with one utility per flat action index
                (length ``num_actions ** 3``).

        Returns:
            ``(action, kld)``: the sampled flat action index (0-d tensor) and
            ``KL(posterior || prior)`` as a Python float.
        """
        self.probabilities = F.softmax(self.params, dim=1)
        # Outer product of the three factors, flattened row-major (first
        # dimension varies slowest), the same order torch.cartesian_prod uses.
        joint_prior = torch.einsum('i,j,k->ijk', self.probabilities[0], self.probabilities[1], self.probabilities[2]).flatten()

        # posterior ∝ prior * exp(beta * U), evaluated in log space so a large
        # beta or utility cannot overflow exp() and produce NaN probabilities.
        posterior = F.softmax(torch.log(joint_prior) + self.beta * utilities, dim=0)

        prior_dist = Categorical(joint_prior)
        posterior_dist = Categorical(posterior)
        return posterior_dist.sample(), kl_divergence(posterior_dist, prior_dist).item()

    def learn(self, action):
        """Take one SGD step increasing the prior probability of ``action``'s weights.

        Args:
            action: flat index into the combinations grid, as returned by
                ``forward``.
        """
        grid = self.combinations if self.combinations is not None else combinations
        # Recompute from the current parameters rather than reusing the value
        # cached by forward(). The result is identical when forward() ran since
        # the last step, and learn() no longer fails if called first.
        prior = Categorical(F.softmax(self.params, dim=1))
        # Shift weight values (e.g. -10..10) to category indices (e.g. 0..20).
        indices = grid[action] + self.offset
        loss = -prior.log_prob(indices).sum()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

In [3]:
class RationalModel(nn.Module):
    """Fully rational baseline agent.

    It always chooses the action with the highest utility. It has no prior and
    no learnable parameters, so ``learn`` does nothing and the KL term it
    reports is always ``0``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, utilities):
        """Return ``(index of the largest utility, 0)``."""
        best_action = utilities.argmax()
        return best_action, 0

    def learn(self, action):
        """No-op: the rational agent never adapts."""
        return None

In [4]:
# Experimental conditions from the local `data` module; each entry is called
# with no arguments and returns the trial inputs `x` and targets `y`.
experiments = [data.default, data.control, data.less_e, data.more_e, data.alternating, data.extreme]
# Independent simulated agents per (alpha, beta) setting.
num_runs = 100

# Seed the global torch RNG so the sampled actions (and therefore every saved
# result) are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Only search for integer weights between -MAX_WEIGHT and +MAX_WEIGHT.
MAX_WEIGHT = 10
values = torch.arange(-MAX_WEIGHT, MAX_WEIGHT + 1, 1)
# Every (w1, w2, w3) triple with the first weight varying slowest; the row
# index is the flat action index used by the models.
combinations = torch.cartesian_prod(values, values, values)

In [5]:
def simulate(experiments, num_runs, gamma, bounded, alphas, betas, out_dir='data', verbose=True):
    """Run every agent configuration on every experiment and save per-trial statistics.

    For each experiment, the utility of every weight combination on every trial
    is precomputed. Then, for each (alpha, beta) pair, ``num_runs`` fresh agents
    are simulated through all trials. One file per experiment is written to
    ``<out_dir>/exp<k>_gamma_<gamma>_bounded_<bounded>.pth``. It contains the list
    ``[num_correct, num_d_solution, num_e_solution, klds]``, each tensor shaped
    ``(num_runs, len(alphas), len(betas), num_trials)``.

    Args:
        experiments: callables returning ``(x, y)`` for one experiment.
        num_runs: number of independent agents per (alpha, beta) pair.
        gamma: weight of the effort cost (sum of squared weights) subtracted
            from each combination's utility.
        bounded: if True use ``Model``; otherwise ``RationalModel``, which
            ignores ``alpha`` and ``beta``.
        alphas: 1-D sweep of learning rates passed to ``Model``.
        betas: 1-D sweep of inverse temperatures passed to ``Model``.
        out_dir: directory for the saved ``.pth`` files; created if missing.
        verbose: print the alpha index as a progress indicator.

    Relies on the notebook-level ``combinations`` and ``values`` grids.
    """
    os.makedirs(out_dir, exist_ok=True)

    # Effort of each weight combination (sum of squared weights). It does not
    # depend on the experiment, so it is computed once.
    effort = (combinations ** 2).sum(-1)
    # Correct answers are further split by effort: below 6 ("d") or exactly 6 ("e").
    is_d_solution = effort < 6
    is_e_solution = effort == 6

    for num_experiment, experiment in enumerate(experiments):
        x, y = experiment()
        num_trials = x.shape[0]

        # precompute utilities, shape (num_combinations, num_trials): 1 where a
        # combination reproduces y exactly, minus the effort cost gamma * sum(w ** 2)
        optimality = ((combinations.float() @ x.t()) == y).double()
        utilities = optimality - gamma * effort.unsqueeze(-1)

        stats_shape = (num_runs, len(alphas), len(betas), num_trials)
        num_correct = torch.zeros(stats_shape)
        num_d_solution = torch.zeros(stats_shape)
        num_e_solution = torch.zeros(stats_shape)
        klds = torch.zeros(stats_shape)

        for alpha_idx, alpha in enumerate(alphas):
            if verbose:
                print(alpha_idx)
            for beta_idx, beta in enumerate(betas):
                for run in range(num_runs):
                    # A fresh agent per run, so learning never carries over between runs.
                    model = Model(values.shape[0], beta, alpha) if bounded else RationalModel()
                    for i in range(num_trials):
                        cell = (run, alpha_idx, beta_idx, i)
                        # inference
                        action, kld = model(utilities[:, i])
                        klds[cell] = kld
                        # adjust prior
                        model.learn(action)

                        if optimality[action, i].bool().item():
                            num_correct[cell] += 1
                            if is_d_solution[action].item():
                                num_d_solution[cell] += 1
                            elif is_e_solution[action].item():
                                num_e_solution[cell] += 1

        path = os.path.join(out_dir, f'exp{num_experiment}_gamma_{gamma!s}_bounded_{bounded!s}.pth')
        torch.save([num_correct, num_d_solution, num_e_solution, klds], path)

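# Simulate full model

Bounded-rational model with a physical-effort cost (`gamma = 0.05`), swept over learning rates `alpha` and inverse temperatures `beta`.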

In [6]:
# Full model: bounded agent with a physical-effort cost.
gamma = 0.05
bounded = True

# Learning rates 0.00, 0.05, ..., 1.00 and inverse temperatures 1, 2, ..., 50.
alphas = torch.linspace(0.0, 1.0, steps=21)
betas = torch.linspace(1.0, 50.0, steps=50)

simulate(experiments, num_runs, gamma, bounded, alphas, betas)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20


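# Simulate model without a physical-effort cost

Same bounded model and parameter sweep, but with `gamma = 0`, so a combination's utility reflects only whether it is correct.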

In [7]:
# Ablation: bounded agent with the physical-effort cost switched off.
gamma = 0.00
bounded = True

# Learning rates 0.00, 0.05, ..., 1.00 and inverse temperatures 1, 2, ..., 50.
alphas = torch.linspace(0.0, 1.0, steps=21)
betas = torch.linspace(1.0, 50.0, steps=50)

simulate(experiments, num_runs, gamma, bounded, alphas, betas)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20


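# Simulate rational model

An unbounded agent that always picks the highest-utility combination. It ignores `alpha` and `beta`, so a single placeholder value is passed for each.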

In [8]:
# Rational baseline: RationalModel ignores alpha and beta, so a single
# placeholder value of 1.0 for each keeps the sweep to one configuration.
gamma = 0.05
bounded = False

alphas = torch.ones(1)
betas = torch.ones(1)

simulate(experiments, num_runs, gamma, bounded, alphas, betas)

0
0
0
0
0
0
