In [None]:
### Import Dependencies ###
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from torch.utils.tensorboard import SummaryWriter


In [None]:
def gaussian_kernel(x, y, sigma=1.0):
    """Gaussian (RBF) kernel matrix between the rows of ``x`` and ``y``.

    Returns K with K[i, j] = exp(-||x_i - y_j||^2 / (2 * sigma^2)).
    """
    # Pairwise Euclidean distances (NOT squared); they are squared below.
    dist = torch.cdist(x, y, p=2, compute_mode="donot_use_mm_for_euclid_dist")
    return torch.exp(-torch.pow(dist, 2) / (2 * sigma ** 2))


def mmd_loss(x, y, sigma=1.0):
    """Unbiased estimate of the squared MMD between samples ``x`` and ``y``.

    Uses the Gaussian kernel from ``gaussian_kernel``.

    Args:
        x: tensor whose n >= 2 rows are samples from the first distribution.
        y: tensor whose m >= 2 rows are samples from the second distribution.
        sigma: kernel bandwidth passed to ``gaussian_kernel``.

    Returns:
        0-d tensor with the unbiased MMD^2 estimate (it can be slightly negative).

    Raises:
        ValueError: if ``x`` or ``y`` has fewer than 2 rows. The unbiased
            normaliser n * (n - 1) would then be zero.
    """
    n = x.size(0)
    m = y.size(0)
    if n < 2 or m < 2:
        raise ValueError("mmd_loss needs at least 2 samples in both x and y")

    xx = gaussian_kernel(x, x, sigma)
    yy = gaussian_kernel(y, y, sigma)
    xy = gaussian_kernel(x, y, sigma)

    # The n * (n - 1) normaliser belongs to the unbiased U-statistic, which
    # averages only over pairs i != j. The diagonal (self-similarity, always
    # exp(0) = 1) must therefore be excluded from the within-sample sums.
    xx_term = (xx.sum() - xx.diagonal().sum()) / (n * (n - 1))
    yy_term = (yy.sum() - yy.diagonal().sum()) / (m * (m - 1))
    xy_term = 2 * xy.sum() / (n * m)
    return xx_term + yy_term - xy_term



In [None]:
class RBF(nn.Module):
    """Multi-bandwidth Gaussian (RBF) kernel.

    ``forward(X)`` returns the matrix
        K[i, j] = sum_k exp(-||X_i - X_j||^2 / (bandwidth * multiplier_k))
    where multiplier_k = mul_factor ** (k - n_kernels // 2) for k = 0 .. n_kernels - 1.

    Args:
        n_kernels: number of bandwidths summed together.
        mul_factor: geometric ratio between consecutive bandwidths.
        bandwidth: fixed base bandwidth. If None, it is recomputed on every
            call as the mean off-diagonal squared distance of the batch.
    """

    def __init__(self, n_kernels=5, mul_factor=2.0, bandwidth=None):
        super().__init__()
        # Float exponents so an integer mul_factor with negative powers does not raise.
        exponents = torch.arange(n_kernels, dtype=torch.get_default_dtype()) - n_kernels // 2
        # Registered as a buffer so .to(device)/.cuda() moves it with the module.
        # persistent=False keeps state_dict() unchanged.
        self.register_buffer("bandwidth_multipliers", mul_factor ** exponents, persistent=False)
        self.bandwidth = bandwidth

    def get_bandwidth(self, L2_distances):
        """Return the base bandwidth for a square matrix of squared distances.

        A fixed ``self.bandwidth`` is returned unchanged. Otherwise the mean
        off-diagonal entry is used, computed without gradient. If that mean is
        undefined or zero (a single sample, or all samples identical), 1.0 is
        used instead so ``forward`` never computes 0 / 0 = NaN.
        """
        if self.bandwidth is not None:
            return self.bandwidth
        n_samples = L2_distances.shape[0]
        n_pairs = n_samples ** 2 - n_samples
        if n_pairs == 0:
            return L2_distances.new_tensor(1.0)
        bandwidth = L2_distances.detach().sum() / n_pairs
        return torch.where(bandwidth > 0, bandwidth, torch.ones_like(bandwidth))

    def forward(self, X):
        L2_distances = torch.cdist(X, X) ** 2
        scales = self.get_bandwidth(L2_distances) * self.bandwidth_multipliers
        return torch.exp(-L2_distances[None, ...] / scales[:, None, None]).sum(dim=0)


class MMDLoss(nn.Module):
    """Biased (V-statistic) estimate of the squared MMD between two samples.

    Args:
        kernel: module or callable that maps a stacked batch of rows to a
            square kernel matrix. Defaults to a fresh ``RBF()`` per instance.
    """

    def __init__(self, kernel=None):
        super().__init__()
        # A default of ``RBF()`` in the signature would be built once, at class
        # definition time, and shared by every MMDLoss instance.
        self.kernel = RBF() if kernel is None else kernel

    def forward(self, X, Y):
        """Return mean(K_XX) - 2 * mean(K_XY) + mean(K_YY) over the joint kernel matrix."""
        K = self.kernel(torch.vstack([X, Y]))

        X_size = X.shape[0]
        XX = K[:X_size, :X_size].mean()
        XY = K[:X_size, X_size:].mean()
        YY = K[X_size:, X_size:].mean()
        return XX - 2 * XY + YY

In [None]:
# Outcome constants. They are not referenced by the visible cells;
# Board.get_status uses the literal values +/-1.0 (winner's marker) and 0.5 (draw).
WIN = 1.0
LOSS = -1.0
DRAW = 0.5
# Initial exploration rate assigned to every new Board.
EPS = 0.7

# TensorBoard logger (default log directory) shared by the training cells.
writer = SummaryWriter()

In [None]:
class Board():
  """Tic-tac-toe board stored as a flat tensor of 9 cells (row-major 3x3).

  Cells hold 0.0 (empty), 1.0 (first mover) or -1.0 (second mover).
  ``marker`` is the value the next move writes. ``game_status`` becomes 1.0
  once the game is over. ``result`` holds the winner's marker, or 0.5 for a draw.
  """

  # Flat indices of the three rows, three columns and two diagonals.
  WINNING_COMBINATIONS = ((0, 1, 2), (3, 4, 5), (6, 7, 8),
                          (0, 3, 6), (1, 4, 7), (2, 5, 8),
                          (0, 4, 8), (2, 4, 6))

  def __init__(self):
    self.reset()
    # Exploration rate read by the policy. reset() does not touch it, so it
    # persists (and can be decayed) across games.
    self.eps = EPS

  def reset(self):
    """Clear the board for a new game; the first mover plays 1.0."""
    self.state = torch.zeros(9)
    self.marker = torch.tensor(1.0)
    self.game_status = torch.tensor(0.0)
    self.result = torch.tensor(0.0)

  def play_move(self, pos):
    """Write the current marker into cell ``pos`` and pass the turn.

    Raises:
      Exception: if cell ``pos`` is already occupied.
    """
    if self.state[pos] != 0.0:
      raise Exception("You made an illegal move")
    self.state[pos] = self.marker
    self.marker = self.marker * -1.0

  def get_status(self):
    """Return ``(game_status, result)`` for the current position.

    A completed line ends the game, and ``result`` is set to that line's
    marker (1.0 or -1.0). A full board with no completed line is a draw
    (``result`` = 0.5). A win on the final, board-filling move counts as a win.
    """
    for a, b, c in self.WINNING_COMBINATIONS:
      if self.state[a] == self.state[b] == self.state[c] != 0.0:
        self.game_status = torch.tensor(1.0)
        # Copy so the returned result does not alias the board tensor.
        self.result = self.state[a].clone()
        # Return immediately: the draw check below must not overwrite a win.
        return self.game_status, self.result
    if torch.count_nonzero(self.state) == 9:
      self.game_status = torch.tensor(1.0)
      self.result = torch.tensor(0.5)
    return self.game_status, self.result

  def get_mask(self):
    """Return a boolean tensor that is True for empty (legal) cells."""
    return self.state == 0.0









In [None]:
class TicTacToeMLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Maps an ``input_dim`` feature vector to ``output_dim`` raw scores through
    one hidden layer of width ``hidden_dim``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Layers are created in this order (fc1, then fc2) so parameter
        # initialisation and state_dict keys match saved checkpoints.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)

class TicTacToePolicy:
    """Masked-softmax move policy over the 9 board cells, backed by TicTacToeMLP.

    Attributes:
      model: TicTacToeMLP(9, 128, 9) mapping a board to 9 move logits.
      optimizer: SGD (lr=1e-2) over ``model``'s parameters.
      eval: when True, play the greedy (argmax) legal move with no exploration.
      state: when True, ``get_action_probabilities`` receives a raw 9-element
        board tensor instead of a ``Board`` object.
    """

    def __init__(self):
        self.model = TicTacToeMLP(9, 128, 9)  # Input: 9 (3x3 board), Hidden: 128, Output: 9 (actions)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=1e-2)
        self.eval = False
        self.state = False

    def get_action_probabilities(self, board):
        """Choose a move for ``board``.

        Args:
          board: a ``Board`` (when ``self.state`` is False) or a raw 9-element
            board tensor (when ``self.state`` is True).

        Returns:
          (log_prob, move, probabilities): log-probability of the chosen move
          under the policy, the move index as a Python int, and the masked
          move distribution.
        """
        if self.state:
            board_state = board
            legal = board == 0
        else:
            board_state = board.state
            legal = board.get_mask()
        mask = legal.float()  # 1.0 for legal cells, 0.0 for occupied ones

        logit = self.model(board_state.clone())
        # A large negative offset drives illegal moves to zero probability.
        logits = logit - 1e9 * (1 - mask)
        probabilities = F.softmax(logits, dim=0)
        m = torch.distributions.categorical.Categorical(probs=probabilities)

        if self.eval:
            # Evaluation: deterministic greedy move, no sampling or exploration.
            move = torch.argmax(probabilities)
        else:
            move = m.sample()
            n_legal = torch.count_nonzero(mask)
            # Board-driven exploration grows as the board fills (fewer legal moves).
            eps_temp = 0.03 if self.state else board.eps * (1 - n_legal / 12)
            if random.random() < eps_temp:
                valid_moves = np.where(mask == 1.0)[0]
                move = torch.tensor(np.random.choice(valid_moves))

        log_prob = m.log_prob(move)
        return log_prob, move.item(), probabilities

In [None]:
class RandomPlayer:
    """Opponent that plays a uniformly random empty cell."""

    def get_action(self, board):
        """Return the index of a randomly chosen legal cell of ``board``."""
        legal_positions = np.flatnonzero(board.get_mask() == 1)
        return np.random.choice(legal_positions)


In [None]:
# Shared board and players used by games():
#   player_1: reference policy loaded from a checkpoint (never stepped by the optimizer below)
#   player_2: freshly initialised policy that the training loop optimizes
#   player_3: random opponent
# NOTE(review): the checkpoint lives on Google Drive; the drive.mount cell
# appears later in the notebook and must be run before this one.
board = Board()
player_1 = TicTacToePolicy()
player_2 = TicTacToePolicy()
# map_location="cpu" lets a checkpoint saved on a GPU runtime load on a CPU-only runtime.
player_1.model.load_state_dict(
    torch.load("/content/drive/MyDrive/policy60000.pt", map_location="cpu"))
# Optional warm start for player_2 instead of a random initialisation:
#player_2.model.load_state_dict(torch.load("/content/drive/MyDrive/policy.pt"))

player_3 = RandomPlayer()


def games(strategy="target"):
  """Play one game: the selected policy (moves first) against player_3.

  Args:
    strategy: "target" selects player_1; any other value selects player_2.

  Returns:
    states: snapshots of the board taken before each of the policy's moves.
    reward: final result from ``board.get_status()``. 1.0 is a policy win,
      -1.0 a player_3 win, and 0.5 a draw.
    log_prob: list of log-probabilities of the policy's chosen moves.
  """
  player = player_1 if strategy == "target" else player_2
  states = []
  log_prob = []

  board.reset()
  while True:
    # play_move mutates board.state in place, so store a copy. Appending the
    # tensor itself would make every entry equal the final position.
    states.append(board.state.clone())
    prob, move, _ = player.get_action_probabilities(board)
    log_prob.append(prob)

    board.play_move(move)
    status, reward = board.get_status()
    if status == 1.0:
      break

    board.play_move(player_3.get_action(board))
    status, reward = board.get_status()
    if status == 1.0:
      break
  return states, reward, log_prob





In [None]:
# Mount Google Drive at /content/drive so the checkpoint paths used by the
# other cells resolve.
# NOTE(review): the player-setup cell earlier in the notebook already loads
# from /content/drive, so this cell must be run before it.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!rm -r runs

In [None]:
###  Define Training Loop ##
# Fine-tune player_2 against the reference player_1. Player_1 is never
# stepped: its MMD outputs are detached and only player_2's optimizer steps.
# Every iteration plays one game per policy against the random opponent.
# player_2 accumulates a policy-gradient term on the squared reward gap.
# Every `batch_size` games it also adds a negated MMD term between the two
# policies' move distributions on sampled positions, then takes one optimizer step.
num_games = 1000000
batch_size = 512         # games per optimizer step
eval_interval = 10000    # games between evaluations of player_2
n_eval_games = 1000      # games per evaluation
max_states = 10000       # buffer size at which sampled positions are discarded
n_mmd_states = 4         # positions compared per MMD term

board.eps = 0.3
player_2.optimizer.zero_grad()
vector = []              # positions sampled from player_2's games
reward_diff = []         # NOTE(review): not referenced in the visible cells
criterion = torch.nn.MSELoss()  # NOTE(review): not referenced in the visible cells
fn = MMDLoss()
reward_a = 0             # NOTE(review): not referenced in the visible cells
reward_b = 0             # NOTE(review): not referenced in the visible cells
batch_reward_gap = 0.0   # sum of squared reward gaps over the current batch

for playouts in range(num_games):

  if len(vector) > max_states:
    vector = []

  states, reward, log_prob = games()

  states_1, reward_1, log_prob_1 = games("neural")
  log_prob_tensor = torch.stack(log_prob_1)
  reward_scaled = (reward_1 - reward) ** 2
  # No leading minus because we are *minimising* the squared reward gap.
  grad_tensor = (log_prob_tensor * reward_scaled).sum()
  grad_tensor.backward()
  batch_reward_gap += float(reward_scaled)
  # Slowly decay the board's exploration rate.
  board.eps = board.eps - board.eps / (num_games - 10000)

  pick = np.random.choice(len(states_1))
  if torch.count_nonzero(states_1[pick]) < 9:
    vector.append(states_1[pick])

  if (playouts + 1) % batch_size == 0:
    if vector:  # torch.stack([]) would raise
      indices = np.random.permutation(np.arange(len(vector)))[:n_mmd_states]
      vec = torch.stack(vector)
      actions_1 = []
      actions_2 = []
      # Switch both policies to raw-tensor input; always switch back.
      player_1.state = True
      player_2.state = True
      try:
        for entry in indices:
          _, _, a1 = player_1.get_action_probabilities(vec[entry])
          _, _, a2 = player_2.get_action_probabilities(vec[entry])
          actions_1.append(a1)
          actions_2.append(a2)
      finally:
        player_1.state = False
        player_2.state = False
      p1_actions = torch.stack(actions_1)
      p2_actions = torch.stack(actions_2)
      # Negated MMD: descending on it pushes player_2's move distribution
      # away from the (detached) player_1 distribution.
      loss = -fn(p1_actions.detach(), p2_actions)
      loss.backward()

    player_2.optimizer.step()
    player_2.optimizer.zero_grad()
    # Mean squared reward gap over the batch that was just optimised.
    writer.add_scalar('Expected difference', batch_reward_gap / batch_size, playouts)
    batch_reward_gap = 0.0

  if playouts % eval_interval == 0:
    # Evaluate player_2 against the random opponent; no gradients needed.
    player_2.eval = True
    win = 0
    loss = 0
    draw = 0
    with torch.no_grad():
      for _ in range(n_eval_games):
        _, eval_reward, _ = games("neural")
        if eval_reward == 1.0:
          win += 1
        elif eval_reward == -1.0:
          loss += 1
        else:
          draw += 1
    player_2.eval = False

    writer.add_scalar("Win_percentage_random", 100 * win / n_eval_games, playouts)
    writer.add_scalar("Loss_percentage_random", 100 * loss / n_eval_games, playouts)
    writer.add_scalar("Draw_percentage_random", 100 * draw / n_eval_games, playouts)
    print("Evaluation after", playouts, "games")
    print("Win", win / n_eval_games)
    print("Loss", loss / n_eval_games)
    print("Draw", draw / n_eval_games)
### Final testing of player_2 against the random opponent, then save it ###
# Only the random player_3 opponent exists in this notebook. games() selects
# player_2 for any strategy other than "target", matching the in-loop
# evaluation and the checkpoint saved below.
player_2.eval = True
loss = 0
draw = 0
win = 0
with torch.no_grad():
  for i in range(1000):
    _, reward, _ = games("neural")
    if reward == 1.0:
      win += 1
    elif reward == -1.0:
      loss += 1
    else:
      draw += 1
player_2.eval = False
print("Final evaluation vs random: Win", win / 1000, "Loss", loss / 1000, "Draw", draw / 1000)

torch.save(player_2.model.state_dict(), "/content/drive/MyDrive/MMD_60k.pt")










Evaluation after 0 games
Win 0.386
Loss 0.278
Draw 0.336
Evaluation after 10000 games
Win 0.474
Loss 0.203
Draw 0.323
Evaluation after 20000 games
Win 0.459
Loss 0.211
Draw 0.33
Evaluation after 30000 games
Win 0.435
Loss 0.222
Draw 0.343
Evaluation after 40000 games
Win 0.468
Loss 0.209
Draw 0.323
Evaluation after 50000 games
Win 0.464
Loss 0.206
Draw 0.33
Evaluation after 60000 games
Win 0.46
Loss 0.212
Draw 0.328
Evaluation after 70000 games
Win 0.454
Loss 0.192
Draw 0.354
Evaluation after 80000 games
Win 0.451
Loss 0.193
Draw 0.356
Evaluation after 90000 games
Win 0.465
Loss 0.215
Draw 0.32
Evaluation after 100000 games
Win 0.474
Loss 0.207
Draw 0.319
Evaluation after 110000 games
Win 0.448
Loss 0.201
Draw 0.351
Evaluation after 120000 games
Win 0.481
Loss 0.187
Draw 0.332
Evaluation after 130000 games
Win 0.48
Loss 0.195
Draw 0.325
Evaluation after 140000 games
Win 0.478
Loss 0.197
Draw 0.325
Evaluation after 150000 games
Win 0.468
Loss 0.218
Draw 0.314
Evaluation after 160000 gam

KeyboardInterrupt: ignored

In [None]:
# Probe both policies on one fixed opening (first mover takes cell 4, the
# centre; second mover takes cell 1) and compare their most likely replies.
board.reset()
for opening_move in (4, 1):
    board.play_move(opening_move)

_, _, prob_1 = player_1.get_action_probabilities(board)
_, _, prob_2 = player_2.get_action_probabilities(board)

best_1 = torch.argmax(prob_1)
best_2 = torch.argmax(prob_2)
print(best_1, best_2)


tensor(8) tensor(0)


In [None]:
# Checkpoint player_2's weights to Drive.
# NOTE(review): the training cell saves to "MMD_60k.pt" (lowercase k) while
# this writes "MMD_60K.pt"; on a case-sensitive filesystem these are two
# different files. Confirm which name is intended.
checkpoint_path = "/content/drive/MyDrive/MMD_60K.pt"
torch.save(player_2.model.state_dict(), checkpoint_path)

In [None]:
# Copy TensorBoard logs to Drive.
# NOTE(review): no visible cell creates `runs_mmd`; SummaryWriter() writes
# under ./runs by default. Confirm the directory was renamed, otherwise copy `runs`.
!cp -r runs_mmd /content/drive/MyDrive