In [None]:
#@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
#in this code let's train an AI to play a game of battleship
#@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@

In [None]:
#@@@@@@@@@@@@@@@@@@@@@@@
#import useful libraries
#@@@@@@@@@@@@@@@@@@@@@@@
import numpy as np
import copy
import seaborn as sb

from tqdm import tqdm
from IPython.display import clear_output
import time
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

import pickle


In [None]:
# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
# the environment class: a battleship board
# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
class Battleship:
    """A class to manage the board

    Attributes:
        config -- dict, read for the keys "n" (board side length) and "ships" (list of ship lengths)
        hit_streak -- int, number of consecutive hits ending at the latest shot
        true_state -- (n, n) int array, 1 where a ship lies and 0 elsewhere (hidden from the agent)
        state -- (n, n) int array of shot results: 0 untried, 1 hit, -1 miss
        old_state -- copy of ``state`` taken at the end of the previous step (or at reset)
    """

    # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    def __init__(self, config):
        """Set up the constructor
        Takes -- config, a dictionary holding the board size "n" and the list of ship lengths "ships"
        """
        self.config = config
        self.hit_streak = 0
        board_shape = [config["n"], config["n"]]
        self.true_state = np.zeros(board_shape, dtype=int)  # no ships until make_board/reset
        self.state = np.zeros(board_shape, dtype=int)  # TODO -- consider how to change this
        self.old_state = self.state.copy()

    # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    def _ship_cells(self, coord_1, coord_2, vert, ship):
        """Build the index of the cells a ship would cover
        Takes:
            self -- class instance info above
            coord_1 -- int, the start position along the ship's own axis
            coord_2 -- int, the fixed position on the other axis
            vert -- boolean, True if the ship runs down a column, False if along a row
            ship -- int, the index in the list of ship lengths
        Returns:
            a tuple usable directly as an index into an (n, n) board array
        """
        # np.arange (rather than list(range(...))) keeps this valid for plain
        # Python ints as well as NumPy integer scalars
        span = coord_1 + np.arange(self.config["ships"][ship])
        if vert:
            return span, coord_2  # rows vary, column fixed
        return coord_2, span  # row fixed, columns vary

    # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    def position_maker(self, ship):
        """A function to choose a position for a ship
        Takes:
            self -- class instance info above
            ship -- int, the index in the list of ship lengths
        Returns:
            an int, int pair: the start along the ship's axis (limited so the whole
            ship fits on the board) and the position on the other axis
        """
        n_starts = self.config["n"] - self.config["ships"][ship] + 1
        coord_1 = np.random.choice(np.arange(n_starts))  # ship is guaranteed to fit
        coord_2 = np.random.choice(np.arange(self.config["n"]))  # any row/column
        return coord_1, coord_2

    # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    def check_empty(self, coord_1, coord_2, vert, ship):
        """A function to make sure the proposed location is currently empty
        Takes:
            self -- class instance info above
            coord_1 -- int, the first coordinate, constrained to ensure the ship lies on the board
            coord_2 -- int, the second coordinate
            vert -- boolean, whether to orient the ship vertically or horizontally
            ship -- int, the index in the list of ship lengths
        Returns:
            a boolean, indicating if the placement is legal
        """
        occupied = self.true_state[self._ship_cells(coord_1, coord_2, vert, ship)]
        # any positive total means another ship already lies across the path
        return not (np.sum(occupied) > 0)

    # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    def place_ship(self, ship):
        """A function to put a single ship on the board
        Takes:
            self -- class instance info above
            ship -- int, the index in the list of ship lengths
        """
        while True:  # resample until the ship does not overlap an earlier one
            coord_1, coord_2 = self.position_maker(ship)  # sample a position
            vert = np.random.uniform() < 0.5  # vertical or horizontal, equally likely
            if self.check_empty(coord_1, coord_2, vert, ship):
                break
        self.true_state[self._ship_cells(coord_1, coord_2, vert, ship)] = 1  # place on board

    # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    def make_board(self):
        """A function to put all ships on the board
        Clears the hidden board first; the observed state is left untouched.
        Takes:
            self -- class instance info above
        """
        self.true_state = np.zeros([self.config["n"], self.config["n"]], dtype=int)
        for ship in range(len(self.config["ships"])):  # loop over ship lengths
            self.place_ship(ship)  # put on board

    # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    def reset(self):
        """A function to reset the board
        Takes:
            self -- class instance info above
        """
        self.hit_streak = 0
        self.make_board()  # clears the hidden state and places the ships again
        self.state = np.zeros(
            [self.config["n"], self.config["n"]], dtype=int
        )  # reset the observed state
        self.old_state = self.state.copy()

    # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    def check_win(self):
        """A function to check whether the agent has won
        Takes:
            self -- class instance info above
        Returns:
            a boolean, True when the number of recorded hits equals the total length of all ships
        """
        return bool(np.sum(self.state > 0) == sum(self.config["ships"]))

    # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    def reward(
        self,
    ):
        """
        A function to compute the reward at each step
            Takes:
                self -- class instance info above
            Returns:
                r -- float, the reward at this step: a weighted sum of the hits on the
                     board, the hits added since ``old_state`` was saved, and the hit streak
        """
        total_hits = np.sum(self.state == 1)
        new_hits = total_hits - np.sum(self.old_state == 1)
        weights = [
            1,  # per hit recorded on the board so far
            1.5,  # per hit added since the previous step
            2,  # per shot in the current hit streak
        ]
        inputs = [total_hits, new_hits, self.hit_streak]
        return np.dot(weights, inputs)

    # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    def step(self, a):
        """A function to update the state given the true enemy fleet locations
        Takes:
            self -- class instance info above
            a -- the action, indexed as a["row"] and a["col"] to pick the cell to fire at
        Returns:
            a dict with keys "state" (a copy of the observed board after the shot),
            "reward" (10 on the winning shot, otherwise the value of reward()) and
            "done" (boolean, whether every ship cell has been hit)
        """
        row, col = a["row"], a["col"]
        if self.true_state[row, col] > 0:  # if there is a ship at the chosen location...
            self.state[row, col] = 1  # record a hit on the observed state as a 1
            self.hit_streak += 1
        else:  # if there is no ship
            self.state[row, col] = -1  # record a miss on the observed state as a -1
            self.hit_streak = 0
        done = self.check_win()  # check to see if the agent has won
        # flat terminal reward on the winning shot, shaped reward otherwise;
        # reward() must run before old_state is refreshed below
        r = 10 if done else self.reward()
        self.old_state = self.state.copy()
        # hand back a copy so stored transitions are not changed by later shots
        return {"state": self.state.copy(), "reward": r, "done": done}

In [None]:
#@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
#A simple CNN architecture
class CNN(nn.Module):
    """Fully convolutional network that outputs one value per board cell.

    The two 3x3 convolutions use padding 1 and the last convolution is 1x1, so
    the spatial size of the input is preserved. The shape notes below use the
    6x6 board as the running example.
    """

    def __init__(self):
        super().__init__()

        # Expected input: (batch_size, 1, 6, 6)
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)  # -> (32, 6, 6)
        self.conv2 = nn.Conv2d(
            in_channels=32, out_channels=64, kernel_size=3, padding=1
        )  # -> (64, 6, 6)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1)  # -> (1, 6, 6)

    def forward(self, x):
        """Map a (batch_size, 1, 6, 6) tensor to a (batch_size, 6, 6) map of values."""
        hidden = x
        for conv in (self.conv1, self.conv2):  # both hidden layers are conv + ReLU
            hidden = F.relu(conv(hidden))
        # the 1x1 convolution collapses the channels to one; drop that channel axis
        return self.conv3(hidden).squeeze(1)


In [None]:
#@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
class Agent:
    """Skeleton of the learning agent.

    Only the constructor does any work so far; every other method is a
    placeholder that returns None until it is filled in.
    """
    #@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    def __init__(self, config):
        """Store the configuration and pull out the action set
            Takes -- config, a dictionary of agent settings; it must contain the key 'A' (the action set)
        """
        self.config = config    #keep the whole config for later use
        self.A_s = config['A']    #the action set
        #TODO: init the deep learning model
        #TODO: copy the target approximator
        #TODO: init the replay buffer

    #@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    def update_Q_target(self):
        """Placeholder: update the target approximator (not implemented yet)
        """
        #TODO

    #@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    def purge_replay_buffer(self):
        """Placeholder: keep the replay buffer within memory limits (not implemented yet)
        """
        #TODO

    #@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    def forward_pass(self):
        """Placeholder: compute the forward pass (not implemented yet)
        """
        #TODO

    #@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    def pi(self):
        """Placeholder: choose actions using Q-values (not implemented yet)
        """
        #TODO

    #@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    def make_batch(self):
        """Placeholder: make a batch for updating Q (not implemented yet)
        """
        #TODO

    #@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    def update_Q_value(self):
        """Placeholder: use the batch to estimate the gradient and take a single step (not implemented yet)
        """
        #TODO

        

In [None]:
#@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
#Set up the classes -- the environment and the agent
#@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
env_config = {'n':6,'ships':[2,3,4],'boards':None}    #set up the environmental config

#the action set: one [row, col] pair per board cell, in row-major order
#the board size is read from env_config so the two configs cannot drift apart
#NOTE(review): Battleship.step indexes its action as a["row"] / a["col"], so these
#[row, col] pairs must be converted before being passed to board.step
A = [[row, col]
     for row in range(env_config['n'])    #loop over board rows
     for col in range(env_config['n'])]    #loop over board columns

agent_config = {'gamma':0.9    #the discount factor
                ,'epsilon':0.1    #the epsilon-greedy parameter
                ,'alpha':0.001    #the learning rate
                ,'A':A    #the action set
                ,'n':env_config['n']    #the size of the board
                ,'M':100000    #set the memory size
                ,'B':32    #set the batch size
                ,'C':500    #when to update the target approximator
                ,'n_steps_for_Q_update':4    #the number of steps to use to update
               }   #examples of agent config

board = Battleship(env_config)    #set up the environment for testing agent performance
agent = Agent(agent_config)    #set up the agent


In [None]:
#@@@@@@@@@@@@@
#Training Loop
#@@@@@@@@@@@@@
#Template for the training loop. NOTE: this cell does not run as written --
#N_eps has no value yet and the body of the while loop is still empty.
N_eps = #TODO -- set the number of training episodes
for epi in tqdm(range(N_eps)):    #loop over episodes
    board.make_board()    #place a fresh fleet (clears and refills the hidden board only)
    done = False    #set the stopping condition
    turn = 0    #init the turn count (nothing in this template increments it yet)
    while not done:    #while the episode goes on... (done must be updated inside the loop)

        #TODO
        #generate an action
        
        #init the tuple for this turn
        #save the old state
        #save the action
        #evolve the environment
        #save the reward
        #save the new state
        
        #add this tuple to the replay buffer
        #remove the oldest tuple if appropriate
        
        #record whether we've stopped

        #update the Q values
        #update the target approximator

    board.reset()    #reset the board: clears the observed state and hit streak and places the ships again
    