In [385]:
import torch
# import torch.optim as optim
from torch import distributions as pyd
import numpy as np
# import environment as env_setup

In [386]:
# write a 6x6 tensor with half zeros and half ones randomly
mdp_params = np.load("mdp/mdp_25.npz")

reward = (np.max(mdp_params['reward']) - mdp_params['reward']) / np.max(mdp_params['reward'])
mu = mdp_params['mu']
p_transition = mdp_params['p_transition'] # S * S' * A
states = mdp_params['states']
actions = mdp_params['n_actions']
H = 100
K = 10
D = 50

# eta = (H^2 * int(states) * int(actions) * K + H^4 * (K + D)) ** (-(1/2))
eta = 0.05
print('eta', eta)

eta 0.05


In [417]:
config = {
        'action_dim': int(actions),
        'state_dim': int(states),
        'transition_function': torch.from_numpy(p_transition),
        'initial_state_distribution': torch.from_numpy(mu),
        'reward_function': torch.from_numpy(reward),

        'training horizon H': H,
        'episodes K': K,
        'eta': eta,
        # 'gamma': 2 * eta * H,   
        'gamma' : 0.005,

        'max_delay': 0
}

In [418]:
config['gamma']

0.005

In [677]:
class DAPO:
    def __init__(self, config): 
        self.transition_function = config['transition_function']
        self.initial_state_distribution = config['initial_state_distribution']
        self.H = config["training horizon H"]
        self.K = config["episodes K"]
        self.reward_function = config['reward_function']
        
        # learning rate, exploration parameter
        self.eta = config["eta"]
        self.gamma = config["gamma"]

        # initialize space dimensions
        self.A = config['action_dim']
        self.S = config['state_dim']

        # self.policy_history = torch.zeros(self.S, self.A, self.H, self.K)
        self.policy_history = (1.0 / self.A) * torch.ones(self.S, self.A, self.H, self.K)

        self.delay_dict = {}
        for i in range(self.K): self.delay_dict[i] = []

        self.traj_history = []

        self.d = config['max_delay']

    def sample_from_logits(self, logits):
        sample = pyd.Categorical(logits=logits).sample()
        return sample.item()

    # WORKS
    def play_one_step(self, current_state, policy, h):  
        # Play one step in the environment
        
        # sample from policy at timestep h
        action = self.sample_from_logits(policy[current_state, :, h])
        # sample from transition to get next state
        next_state = self.sample_from_logits(self.transition_function[:, current_state, action])
        # get reward for state and action
        reward = self.reward_function[current_state, action]
        return action, next_state, reward

    # WORKS
    def play_episode(self, policy):
        traj = []
        # sample initial state
        current_state = self.sample_from_logits(self.initial_state_distribution)
        
        # for 0, ..., H-1 play a step and set the current state to the next state
        for h in range(self.H):
            action, next_state, reward = self.play_one_step(current_state, policy, h)
            traj.append((current_state, action, reward))
            current_state = next_state
        return traj
    
    # WORKS
    def observe_feedback(self, traj):
        # returns tensor of H * 1 rewards from a given trajectory
        rewards = torch.zeros(self.H)
        for h in range(self.H - 1):
            s, a = traj[h][0], traj[h][1]
            rewards[h] = self.reward_function[s, a]
        return rewards


    # rewrite the get_n_step trajectory
    # def get_n_step_transition (self, h, k):
    #     p = torch.zeros(self.S, self.S):
    #     for i 
    
    # WORKS (ish)
    def get_one_step_transitions (self, k):
        # returns a tensor of shape S * S * H which gives the one step transition probabilities
        # from state i to j at each timestep h
        p = torch.zeros(self.S, self.S, self.H)
        for h in range(self.H):
            for a in range(self.A):
                p[:, :, h] += self.policy_history[:, a, h, k] * self.transition_function[:, :, a]
        return p
    
    # WORKS (ish)
    def get_occupancy(self, h, k):
        # get the n-step transition probabilities
        # p = self.get_one_step_transition(h, k)

        p = self.get_one_step_transitions(k)

        # occupancy measure is S * S

        p_out = torch.zeros(self.S, self.S)
        # p_out[:, :] = torch.ones(self.S, self.S)
        p_out = p[:, :, 0]

        for i in range(2, h+1):
            for s in range(self.S):
                p_out = p_out * p[:, :, i-1]

        # # use dynamic programming to get the n-step transition probabilities
        # for i in range(1, h+1):
        #     for s in range(self.S):
        #         # for s_p in range(self.S):
        #         #     p_out[s, s_p, i] = p_out[s, s_p, i-1] * p[s, s_p, i-1]
        #         p_out[:, :, i] += p_out[:, s, i-1] * p[:, :, i]

        return p_out

        # for i in range

        # for i in range(1, h + 1):
        #     p_h @= self.get_one_step_transition(i, k)
        # p_h = torch.linalg.matrix_power(p, h)

        # return p_h

        # adjust for the fact that initial state is stochastic
        # p_adj = torch.zeros(self.S)

        # for s in range(self.S):
        #     p_adj += self.initial_state_distribution[s] * p_out[s, :, h]

        # return p_adj

        # # p_adj represents the unconditional n-step transition probabilities
        # # p_adj[i] is the probability of being in state i after h steps

        # # occ_measure is a tensor of shape S * A
        # occ_measure = p_adj.repeat(self.A, 1).T * self.policy_history[:, :, h, k]
        
        # # we want the sum over actions
        # occ_measure = torch.sum(occ_measure, dim = 1)
        # return p_adj, occ_measure

    def unzip_trajectory(self, traj):
        l = [list(t) for t in zip(*traj)]
        return l[0], l[1], l[2]

    def get_r(self, j_policy, k_policy):
        # takes in two tensors of shape S * A * H
        # outputs a tensor of shape S * A * H
        return j_policy / torch.maximum(j_policy, k_policy)

    def get_b(j_policy, k_policy, r, occupancies):
        # outputs a tensor of shape S * H
        output = torch.zeros(self.S, self.H)
        for a in range(self.A):
            output += (3 * self.gamma * self.H * k_policy[:, a, :] * r[:, a, :]) / (occupancies * j_policy[:, a, :] + self.gamma)

    def train_one_step(self, k):
        delay = np.random.choice(np.arange(k, k + self.d + 1))
        self.delay_dict[delay].append(k)
        delayed = self.delay_dict[k]

        # play episode k with policy at k
        k_trajectory = self.play_episode(self.policy_history[:,:,:,k])
        self.traj_history.append(k_trajectory)
        k_rewards = self.observe_feedback(k_trajectory)

        # print the total rewards from episode k
        print(f"Episode: {k}, Total Reward: {torch.sum(k_rewards)}")

        # initialize Q and B
        Q = torch.zeros(self.S, self.A, self.H, len(delayed))
        B = torch.zeros(self.S, self.A, self.H + 1, len(delayed))

        # the j are episode indexes for which we can now 'observe' the rewards
        for j in delayed:
            j_trajectory = self.traj_history[j]
            j_rewards = self.observe_feedback(j_trajectory)

            j_policy = self.policy_history[:, :, :, j]
            k_policy = self.policy_history[:, :, :, k]
            
            # r is an S * A * H tensor
            r = self.get_r(j_policy, k_policy)
            # print('r', r)

            # L is an H * 1 tensor
            occupancies = torch.zeros(self.S, self.H)
            L = torch.zeros(self.H)
            # calculate occupanies as well, which is an S * H tensor
            for h in range(self.H-1, -1, -1):
                L[h] = torch.sum(j_rewards[h:])
                occupancies[:, h] = self.get_occupancy(h, j)

            print('occupancies', occupancies[:, 99])
            print('occupancy sum', torch.sum(occupancies[:, 99]))
            print('L', L)
                
            s_j, a_j, _ = self.unzip_trajectory(j_trajectory)

            # Q is an S * A * H * J tensor
            Q[s_j, a_j, :, j] = r[s_j, a_j] * L / (occupancies[s_j] * j_policy[s_j, a_j] + self.gamma)

            print('Q', Q[:, :, 0, 0])

            # b is an S * H tensor
            b = torch.zeros(self.S, self.H)
            for a in range(self.A):
                # print('k_policy shape', k_policy[:, a, :].shape)
                # print('r shape', r.shape)
                # print('occupancies shape', occupancies.shape)
                # print('j_policy shape', j_policy.shape)
                # print('x shape', x.shape)
                b += (3 * self.gamma * self.H * k_policy[:, a, :] * r[:, a, :]) / (occupancies * j_policy[:, a, :] + self.gamma)

            print('b at H-2', b[:, self.H-2])
            print('b at H-3', b[:, self.H-3])
            print('b at 0', b[:, 0])
            # now we calculate B which is S * A * H * J

            for h in range(self.H-2, -1, -1):
                # Now we calculate B for S * A in time h and for episode j

                # B[:, :, h, j] = b[:, h].repeat(self.A, 1).T  
                inner_sum = torch.zeros(self.S, self.A)
                for s in range(self.S):
                    for a in range(self.A):
                        # print('self.transition_function shape', self.transition_function[s, : ,a].repeat(self.A, 1).T.shape)
                        # print('j_policy shape', self.policy_history[:, :, h+1, j].shape)
                        # print('B shape', B[:, :, h+1, j].shape)
                        inner_sum += self.transition_function[s, : ,a].repeat(self.A, 1).T * self.policy_history[:, :, h+1, j] * B[:, :, h+1, j]
                B[:, :, h, j] = b[:, h].repeat(self.A, 1).T + inner_sum

                # Now we update the policy

        # this is not right
        print('B at H-1', B[:, :, self.H-1, 0])
        print('B at H-2', B[:, :, self.H-2, 0])
        print('B at H-3', B[:, :, self.H-3, 0])

        for h in range(self.H):
            for s in range(self.S):
                for a in range(self.A):
                    
                    num_sum = 0.
                    for j in delayed:
                        num_sum += Q[s, a, h, j] - B[s, a, h, j]

                    print('num_sum', num_sum)
                    numerator = self.policy_history[s, a, h, k] * np.exp(-1 * self.eta * num_sum)

                    denominator = 0.

                    for a_prime in range(self.A):
                        inner_denom_sum = 0.
                        for j in delayed:
                            inner_sum += Q[s, a_prime, h, j] - B[s, a_prime, h, j]
                        denominator += self.policy_history[s, a_prime, h, k] * np.exp(-1 * self.eta * inner_sum)

                    print('numerator', numerator)
                    print('denominator', denominator)
                    
                    if k != (self.K - 1): # do not update policy at last episode
                        # self.policy_history[s][a][h][k + 1] = numerator / (denominator if denominator != 0 else 1)
                        self.policy_history[s][a][h][k + 1] = numerator / denominator

    def train(self):
        for k in range(self.K):
            self.train_one_step(k)
        print(self.delay_dict)

In [678]:
DAPO_test = DAPO(config)

In [684]:
DAPO_test.get_occupancy(99, 0)[0]

tensor([1.4013e-45, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00])

In [648]:
DAPO_test.get_one_step_transitions(0)[:, :, 99]

tensor([[6.7497e-01, 8.3333e-04, 8.3333e-04, 8.5338e-02, 8.3333e-04, 8.3333e-04,
         8.3333e-04, 8.3397e-04, 8.4658e-02, 8.3333e-04, 8.3333e-04, 9.2366e-04,
         8.3333e-04, 8.3341e-04, 8.3333e-04, 8.3333e-04, 2.0086e-02, 8.3333e-04,
         1.4896e-03, 8.3333e-04, 8.3333e-04, 3.4823e-03, 8.3333e-04, 8.4229e-02,
         3.1488e-02],
        [6.6758e-03, 6.1377e-01, 6.6667e-03, 6.6667e-03, 6.6667e-03, 6.6667e-03,
         6.6667e-03, 6.6667e-03, 6.7485e-03, 6.6667e-03, 6.6667e-03, 9.1737e-02,
         6.6667e-03, 6.6667e-03, 6.6667e-03, 1.0142e-02, 6.6679e-03, 6.6667e-03,
         4.8768e-02, 6.6667e-03, 6.6667e-03, 9.6945e-02, 6.6667e-03, 1.1875e-02,
         6.6667e-03],
        [6.6828e-03, 6.6667e-03, 7.9442e-01, 6.6667e-03, 6.6667e-03, 6.6667e-03,
         6.6667e-03, 6.6667e-03, 6.6667e-03, 6.6667e-03, 6.6667e-03, 1.0736e-02,
         6.6667e-03, 6.6673e-03, 6.6667e-03, 6.6667e-03, 1.9488e-02, 6.6667e-03,
         6.6667e-03, 6.6667e-03, 6.6667e-03, 9.2708e-03, 6.6667e-

In [632]:
DAPO_test.get_occupancy(99, 0)

tensor([0.1279, 0.0046, 0.0124, 0.0747, 0.0176, 0.0026, 0.0060, 0.0372, 0.0635,
        0.0043, 0.0167, 0.1032, 0.0061, 0.0765, 0.0026, 0.0225, 0.0720, 0.0083,
        0.0438, 0.0046, 0.0143, 0.0683, 0.0056, 0.1142, 0.0906])

In [635]:
DAPO_test.train_one_step(0)

Episode: 0, Total Reward: 77.35842895507812
occupancies tensor([0.1279, 0.0046, 0.0124, 0.0747, 0.0176, 0.0026, 0.0060, 0.0372, 0.0635,
        0.0043, 0.0167, 0.1032, 0.0061, 0.0765, 0.0026, 0.0225, 0.0720, 0.0083,
        0.0438, 0.0046, 0.0143, 0.0683, 0.0056, 0.1142, 0.0906])
occupancy sum tensor(1.0000)
L tensor([77.3584, 76.3754, 75.3947, 74.5676, 73.5869, 73.4024, 73.4024, 73.4024,
        72.4217, 71.4383, 70.4625, 69.8617, 68.8793, 68.4598, 67.4774, 67.4774,
        66.5051, 65.5226, 64.5392, 63.5557, 62.5732, 61.5909, 61.5909, 60.6075,
        59.7804, 59.5665, 58.6782, 58.4937, 57.6054, 56.6218, 55.7335, 54.9323,
        53.9500, 52.9742, 51.9907, 51.7767, 50.7943, 50.1935, 49.2100, 48.2377,
        47.2542, 46.4529, 45.4705, 44.4875, 43.5045, 42.5220, 41.5395, 40.5561,
        39.5727, 38.7457, 37.9016, 36.9293, 35.9535, 35.5340, 34.5510, 33.7239,
        32.7415, 31.7762, 30.7938, 29.8108, 28.8284, 27.8451, 26.8972, 26.0864,
        25.8725, 25.6880, 25.5034, 25.2895, 24.3

  numerator = self.policy_history[s, a, h, k] * np.exp(-1 * self.eta * num_sum)
  denominator += self.policy_history[s, a_prime, h, k] * np.exp(-1 * self.eta * inner_sum)


RuntimeError: expand(torch.FloatTensor{[25, 6]}, size=[]): the number of sizes provided (0) must be greater or equal to the number of dimensions in the tensor (2)

In [572]:
DAPO_test.get_occupancy(99, 0)

tensor([0.1279, 0.0046, 0.0124, 0.0747, 0.0176, 0.0026, 0.0060, 0.0372, 0.0635,
        0.0043, 0.0167, 0.1032, 0.0061, 0.0765, 0.0026, 0.0225, 0.0720, 0.0083,
        0.0438, 0.0046, 0.0143, 0.0683, 0.0056, 0.1142, 0.0906])

In [562]:
DAPO_test.get_one_step_transition(0, 0)[0]

tensor([0.6750, 0.0008, 0.0008, 0.0853, 0.0008, 0.0008, 0.0008, 0.0008, 0.0847,
        0.0008, 0.0008, 0.0009, 0.0008, 0.0008, 0.0008, 0.0008, 0.0201, 0.0008,
        0.0015, 0.0008, 0.0008, 0.0035, 0.0008, 0.0842, 0.0315])

In [489]:
x = torch.tensor([[2, 2], [2, 2]])
torch.linalg.matrix_power(x, 2)

tensor([[8, 8],
        [8, 8]])

In [516]:
torch.sum(DAPO_test.get_one_step_transition(0, 0)[0])


tensor(1.)

In [503]:
x = DAPO_test.get_one_step_transition(99, 0)

In [487]:
y = torch.linalg.matrix_power(x, 99)

In [488]:
y

tensor([[0.1367, 0.0081, 0.0653, 0.0468, 0.0141, 0.0029, 0.0124, 0.0266, 0.0468,
         0.0083, 0.0305, 0.1312, 0.0126, 0.0660, 0.0034, 0.0293, 0.0457, 0.0125,
         0.0246, 0.0087, 0.0217, 0.0462, 0.0175, 0.0984, 0.0834],
        [0.1367, 0.0081, 0.0653, 0.0468, 0.0141, 0.0029, 0.0124, 0.0266, 0.0468,
         0.0083, 0.0305, 0.1312, 0.0126, 0.0660, 0.0034, 0.0293, 0.0457, 0.0125,
         0.0246, 0.0087, 0.0217, 0.0462, 0.0175, 0.0984, 0.0834],
        [0.1367, 0.0081, 0.0653, 0.0468, 0.0141, 0.0029, 0.0124, 0.0266, 0.0468,
         0.0083, 0.0305, 0.1312, 0.0126, 0.0660, 0.0034, 0.0293, 0.0457, 0.0125,
         0.0246, 0.0087, 0.0217, 0.0462, 0.0175, 0.0984, 0.0834],
        [0.1367, 0.0081, 0.0653, 0.0468, 0.0141, 0.0029, 0.0124, 0.0266, 0.0468,
         0.0083, 0.0305, 0.1312, 0.0126, 0.0660, 0.0034, 0.0293, 0.0457, 0.0125,
         0.0246, 0.0087, 0.0217, 0.0462, 0.0175, 0.0984, 0.0834],
        [0.1367, 0.0081, 0.0653, 0.0468, 0.0141, 0.0029, 0.0124, 0.0266, 0.0468,
       

In [504]:
torch.sum(x)

tensor(25.0000)

In [252]:
x = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
y = torch.tensor([[1, 2], [3, 4]])

In [261]:
y[:, 0]

tensor([1, 3])

In [260]:
y.repeat(1, 1, 2)[:, 0, :]

tensor([[1, 2, 1, 2]])

In [250]:
x.repe[:, 0, :]

tensor([[1, 2],
        [5, 6]])

In [247]:
x[:, 0].repeat(2, 1)

tensor([[1, 3],
        [1, 3]])

In [239]:
x[:, 0]

tensor([1, 3])

In [240]:
x

tensor([[1, 2],
        [3, 4]])

In [199]:
x[0]

tensor([1, 2])

In [198]:
torch.sum(x, dim = 1)

tensor([3, 7])

In [201]:
occupancy

tensor([[6.7497e-01, 8.3333e-04, 8.3333e-04, 8.5338e-02, 8.3333e-04, 8.3333e-04,
         8.3333e-04, 8.3397e-04, 8.4658e-02, 8.3333e-04, 8.3333e-04, 9.2366e-04,
         8.3333e-04, 8.3341e-04, 8.3333e-04, 8.3333e-04, 2.0086e-02, 8.3333e-04,
         1.4896e-03, 8.3333e-04, 8.3333e-04, 3.4823e-03, 8.3333e-04, 8.4229e-02,
         3.1488e-02],
        [6.6758e-03, 6.1377e-01, 6.6667e-03, 6.6667e-03, 6.6667e-03, 6.6667e-03,
         6.6667e-03, 6.6667e-03, 6.7485e-03, 6.6667e-03, 6.6667e-03, 9.1737e-02,
         6.6667e-03, 6.6667e-03, 6.6667e-03, 1.0142e-02, 6.6679e-03, 6.6667e-03,
         4.8768e-02, 6.6667e-03, 6.6667e-03, 9.6945e-02, 6.6667e-03, 1.1875e-02,
         6.6667e-03],
        [6.6828e-03, 6.6667e-03, 7.9442e-01, 6.6667e-03, 6.6667e-03, 6.6667e-03,
         6.6667e-03, 6.6667e-03, 6.6667e-03, 6.6667e-03, 6.6667e-03, 1.0736e-02,
         6.6667e-03, 6.6673e-03, 6.6667e-03, 6.6667e-03, 1.9488e-02, 6.6667e-03,
         6.6667e-03, 6.6667e-03, 6.6667e-03, 9.2708e-03, 6.6667e-

In [265]:
DAPO_test.get_occupancy(1, 0)

tensor([0.1077, 0.0243, 0.0140, 0.0836, 0.0313, 0.0458, 0.0150, 0.0656, 0.0819,
        0.0242, 0.0289, 0.0720, 0.0140, 0.0762, 0.0395, 0.0276, 0.0793, 0.0209,
        0.0776, 0.0248, 0.0187, 0.0697, 0.0097, 0.0875, 0.0794])

In [None]:
DAPO_test.play_episode(DAPO_test.policy_history[:,:,:,0])

In [None]:
DAPO_test.get_occupancy(0, 2)

In [150]:
x = torch.tensor([1, 7, 8])
y = torch.tensor([4, 5, 6])

In [151]:
x / (torch.maximum(x, y))

tensor([0.2500, 1.0000, 1.0000])

In [262]:
torch.tensor([1, 2]) * torch.tensor([[3, 4], [5, 6]])

tensor([[ 3,  8],
        [ 5, 12]])

In [117]:
traj_test = DAPO_test.play_episode(DAPO_test.policy_history[:, :, :, 0])

In [140]:
x = torch.tensor([ 0.0000, 64.3088, 54.7057, 64.8670, 58.7536, 52.9981, 64.9772, 12.2039,
        64.9661, 52.9981, 64.9803, 14.1496, 14.1496, 64.9803, 52.9981, 64.9661,
        64.3088, 27.7446, 64.3088, 65.0443, 65.1410, 14.1496, 14.1496, 55.8280,
        64.8670, 64.8670, 62.6964, 64.5440, 65.1419, 62.6964, 64.8670, 64.9661,
        58.7536, 65.0443, 27.7446, 14.1496,  0.0000, 64.8670, 63.8448, 64.9823,
        64.9823, 64.5440, 64.3088, 64.9803,  0.0000, 64.3088, 64.5440, 64.8670,
        14.1496, 14.1496, 39.7341, 65.1410, 64.3088, 65.1419, 63.8448, 14.1496,
        65.1419, 65.0418, 64.9661, 54.7057, 27.7446, 55.8280, 55.8280,  0.0000,
        64.9803, 65.0443, 12.2039, 64.5440, 63.8448, 65.0533, 54.7057, 65.0533,
        64.3088, 63.8448, 64.8670, 53.6274, 58.7536, 64.9661, 65.0173,  0.0000,
        65.1419, 65.0173, 58.7536, 64.5440, 65.0533, 64.9823, 63.8448, 65.1419,
        55.8280,  0.0000, 52.9981, 52.9981, 14.1496, 55.8280, 63.8448, 65.0533,
        54.7057, 12.2039, 62.6964,  0.0000])

torch.sum(x[5:])

tensor(4845.0845)

In [95]:
DAPO_test.policy_history[:, :, 99, 9]

tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.166

In [103]:
s_p = torch.tensor()

5
4
3
2
1
