In [155]:
import numpy as np
import os

class Connect3Environment():
    """Connect-3 on an ``n_rows`` x ``n_cols`` grid, exposed as an MDP model.

    A state is any sequence that reshapes to ``(n_rows, n_cols)``; a cell
    holds 0 (empty), 1 or 2 (the owning player). Row 0 is the top of the
    board: a column is playable while its row-0 cell is empty, and a dropped
    piece settles in the free cell with the highest row index.
    """

    # Number of aligned pieces that ends the game.
    WIN_LENGTH = 3

    def __init__(self, n_rows, n_cols, value_fun):
        """Store the board size, create (and print) an empty board.

        Args:
            n_rows: number of board rows.
            n_cols: number of board columns.
            value_fun: mapping whose keys are the states of the MDP
                (see ``get_all_states``).
        """
        self.n_cols = n_cols
        self.n_rows = n_rows
        self.board = self.initialize_board()
        print(self.board)

        self.value_fun = value_fun

    def initialize_board(self):
        """Return an all-zero ``(n_rows, n_cols)`` array."""
        return np.zeros((self.n_rows, self.n_cols))

    def reset(self):
        """Replace ``self.board`` with a fresh empty board."""
        self.board = self.initialize_board()

    def get_all_states(self):
        """Return the keys of ``value_fun`` as a list."""
        return list(self.value_fun)

    def is_terminal(self, state):
        """Return True if ``state`` holds a finished line or has no empty cell.

        A line is ``WIN_LENGTH`` equal non-zero cells running right, down,
        diagonally down-right or diagonally up-right from some cell. The
        owner of the line is not reported.
        """
        grid = np.array(state).reshape(self.n_rows, self.n_cols)
        length = self.WIN_LENGTH
        # Last start index from which a line still fits on the board.
        last_row_start = self.n_rows - length
        last_col_start = self.n_cols - length

        for row in range(self.n_rows):
            for col in range(self.n_cols):
                player = grid[row][col]
                if player == 0:
                    continue
                fits_right = col <= last_col_start
                fits_down = row <= last_row_start
                fits_up = row >= length - 1
                if (fits_right and all(grid[row][col+i] == player for i in range(length))) or \
                   (fits_down and all(grid[row+i][col] == player for i in range(length))) or \
                   (fits_down and fits_right and all(grid[row+i][col+i] == player for i in range(length))) or \
                   (fits_up and fits_right and all(grid[row-i][col+i] == player for i in range(length))):
                    return True
        # No line found: the game is over only if the board is full (draw).
        return not any(0 in row for row in grid)

    def get_possible_actions(self, state):
        """Return the column indices whose top cell (row 0) is still empty."""
        grid = np.array(state).reshape(self.n_rows, self.n_cols)
        return [col for col in range(self.n_cols) if grid[0, col] == 0]

    def calculate_turn(self, state):
        """Return the player to move, derived from the piece counts.

        Returns 1 when both players have the same number of pieces and 2 when
        player 1 has more. When player 2 has more pieces the board is printed
        and None is returned.
        """
        grid = np.array(state).reshape(self.n_rows, self.n_cols)
        player1_count = np.sum(grid == 1)
        player2_count = np.sum(grid == 2)

        if player1_count == player2_count:
            return 1
        elif player1_count > player2_count:
            return 2
        else:
            print(grid)
            return None

    def get_next_states(self, state, action):
        """Return ``{next_state: probability}`` after ``action`` and one reply.

        ``action`` is applied for the player to move, then every column that
        is still playable is tried as the reply, each with equal probability.
        Keys are flattened state tuples. The result is empty when no reply is
        possible after ``action``.
        """
        probs = {}
        after_move = self.step(state, action)
        replies = self.get_possible_actions(after_move)
        n_valid = len(replies)

        if n_valid == 0:
            return probs

        for reply in replies:
            after_reply = self.step(after_move, reply)
            probs[tuple(after_reply.flatten())] = 1/n_valid

        return probs

    def step(self, state, action):
        """Return a new board with the mover's piece dropped into ``action``.

        The input is copied, not modified. If the column has no empty cell the
        copy is returned unchanged.
        """
        board = np.array(state).reshape(self.n_rows, self.n_cols)
        player = self.calculate_turn(board)

        # Scan from the bottom row upwards for the first free cell.
        for row in range(self.n_rows - 1, -1, -1):
            if board[row, action] == 0:
                board[row, action] = player
                break

        return board

    def get_reward(self, state, action, next_state):
        """Return the reward for taking ``action`` in ``state`` and reaching ``next_state``.

        Returns 0 for a full ``state``, -1 when ``state`` is already terminal,
        +1 when ``next_state`` is terminal and 0 otherwise.

        Raises:
            AssertionError: if ``action`` is not playable in ``state``.
        """
        grid = np.array(state).reshape(self.n_rows, self.n_cols)
        assert action in self.get_possible_actions(grid)

        # NOTE(review): is_terminal does not say which player completed the
        # line (and is also True for a full board), so the +1 branch fires for
        # any terminal next_state - confirm this is the intended reward when
        # next_state already includes the opponent's reply.
        # draw
        if not any(0 in row for row in grid):
            return 0
        # current state is terminal - lost game
        elif self.is_terminal(state):
            return -1
        # next_state is terminal - game ended after our move
        elif self.is_terminal(next_state):
            return 1
        else:
            return 0

In [None]:
import pickle

# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open("value_function_passive.pkl", mode="rb") as pickle_file:
    value_fun = pickle.load(pickle_file)
print('Value function loaded from the file.')

Value function loaded from the file.


In [157]:
# Keep only the states in which both players have placed the same number of
# pieces, and start each of them with a value of 0.
filtered_value_function = {}
for key in value_fun:
    reshaped_board = np.array(key).reshape(4, 5)
    player1_count = np.count_nonzero(reshaped_board == 1)
    player2_count = np.count_nonzero(reshaped_board == 2)

    if player1_count != player2_count:
        continue
    filtered_value_function[key] = 0

print(f"Filtered length: {len(filtered_value_function)}")

Filtered length: 292495


In [158]:
# 4-row by 5-column board whose state set is the filtered value function keys.
env = Connect3Environment(n_rows=4, n_cols=5, value_fun=filtered_value_function)

[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]


In [159]:
def value_iteration(mdp, gamma, theta):
    """
    Compute a greedy policy and its value function with Value Iteration.

    'mdp' - model of the environment; the following methods are used:
        get_all_states - return the states available in the environment
        get_possible_actions - return the list of actions possible in a state
        get_next_states - return a mapping {next_state: probability} for
                          taking an action in a state
        get_reward - return the reward for taking action in state and
                     landing on next_state

    'gamma' - discount factor for the MDP
    'theta' - sweeps stop once the largest change of any state value within
              one sweep is smaller than theta

    Values are updated in place during a sweep, so later states in the same
    sweep already see the new values of earlier ones. Successor states that
    are not among the MDP's states contribute a value of 0. The largest
    change of each sweep is printed as a progress indicator.

    Returns the tuple (policy, V): 'policy' maps each state that has at
    least one action to its best action (the first one on ties), 'V' maps
    every state to its value.
    """
    V = dict()
    policy = dict()

    # Start from zero values and the first available action for each state.
    for current_state in mdp.get_all_states():
        V[current_state] = 0
        possible_actions = mdp.get_possible_actions(current_state)
        if len(possible_actions) > 0:
            policy[current_state] = possible_actions[0]

    while True:
        delta = 0
        for s in mdp.get_all_states():
            best_value = None
            best_action = None
            for action in mdp.get_possible_actions(s):
                # Look the transition model up once per (state, action)
                # instead of once per successor.
                transitions = mdp.get_next_states(s, action)
                action_value = 0
                for sprim, probability in transitions.items():
                    action_value = action_value + probability * (mdp.get_reward(s, action, sprim) + gamma * V.get(sprim, 0))
                # Strict comparison keeps the first action among equal maxima.
                if best_value is None or action_value > best_value:
                    best_value = action_value
                    best_action = action

            # States without actions keep their current value.
            if best_value is None:
                continue

            delta = max(delta, abs(V[s] - best_value))
            V[s] = best_value
            policy[s] = best_action

        print(delta)
        if delta < theta:
            break
    return policy, V

In [160]:
# Discount factor 0.9; stop once a sweep changes no state value by 0.001 or more.
optimal_policy, optimal_value = value_iteration(env, gamma=0.9, theta=0.001)

1.9
1.17
0.405
0.235
0.03644999999999998
0.0026818087500000143
0.00034992000000000356


In [None]:
# Persist the computed policy in the working directory.
with open("passive_learning_policy.pkl", mode="wb") as policy_file:
    pickle.dump(optimal_policy, policy_file)

In [None]:
# Persist the computed value function in the working directory. The former
# "C:" prefix made the path drive-relative on Windows and produced a file
# literally named "C:passive_learning_value_fun.pkl" on other systems.
with open("passive_learning_value_fun.pkl", "wb") as fw:
    pickle.dump(optimal_value, fw)