In [1]:
import numpy as np
import random
import copy
import time
from collections import Counter
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib import rcParams
from IPython.display import Image, display, HTML
from IPython.display import clear_output

%matplotlib inline

### Functions for Mahjong state interaction

In [2]:
# Tile index -> tile name. Indices 0-26 are the three numbered suits
# (Tong, Wan, Suo; ranks 1-9 each), 27-30 are the four winds and
# 31-33 are the three dragons. The row index of a game state maps here.
mahjong_dict = dict(enumerate(
    [f'{rank} {suit}' for suit in ('Tong', 'Wan', 'Suo') for rank in range(1, 10)]
    + ['Dong', 'Nan', 'Xi', 'Bei']
    + ['Bai_Ban', 'Fa_Cai', 'Hong_Zhong']
))

def initialize_mahjong(n_tile_types=34, n_copies=4, hand_size=14):
    """
    Creates a random initialization of a game of Mahjong.

    Parameters
    ----------
    n_tile_types : int
        Number of distinct tiles. Defaults to 34 so that every entry of
        mahjong_dict (indices 0-33) can be dealt; the previous hard-coded 33
        left the last tile (index 33) out of the game entirely.
    n_copies : int
        Number of copies of each tile (default 4).
    hand_size : int
        Number of tiles dealt to the player (default 14).

    Returns
    -------
    2D array of shape (n_tile_types, n_copies) representing the state
    Axis 0: Index denotes tile (see mahjong_dict)
    Axis 1: Index denotes copy number of that tile
    Values: 0 = Draw Pile, 1 = Player's Hand, 10 = Discard Pile

    Raises
    ------
    ValueError (from random.sample) if hand_size exceeds the number of tiles.
    """
    # Sample distinct flat tile ids without replacement.
    dealt_tiles = random.sample(range(n_tile_types * n_copies), hand_size)
    state = np.zeros([n_tile_types, n_copies])
    for flat_index in dealt_tiles:
        # quotient gives tile index, remainder gives copy index
        tile_index, copy_index = divmod(flat_index, n_copies)
        state[tile_index, copy_index] = 1
    return state

def discard_tile(state,action):
    """
    Returns the 2D array indices (tile index, copy index) of the tile
    selected by an action.
    Action = position of the tile within the player's hand, counted in
    flattened (row-major) order, e.g. a value between 0 and 13.
    Helper function for update_state
    """
    # Flat positions of every cell currently in the hand (value 1),
    # e.g. [1, 8, 10, 11, 27, ...]; action = 2 then selects flat position 10.
    hand_positions = np.flatnonzero(state.flatten() == 1).tolist()
    chosen_position = hand_positions[action]
    # Flat position -> (row, column) for a 4-column state.
    return divmod(chosen_position, 4)

# Pick, uniformly at random, one cell that is still in the draw pile (value 0)
def draw_tile(state):
    """
    Returns the 2D array indices (tile index, copy index) of a randomly
    drawn tile.
    Helper function for update_state
    """
    draw_pile = np.flatnonzero(state.flatten() == 0).tolist()
    picked_position = random.choice(draw_pile)
    # Flat position -> (row, column) for a 4-column state.
    return divmod(picked_position, 4)

# One game step: discard the chosen tile, then draw a random replacement
def update_state(state,action):
    """
    Input current state and action (0-13).
    Moves the tile selected by the action to the discard pile, then moves a
    randomly chosen draw-pile tile into the player's hand.
    The state array is modified in place and also returned.
    """
    # discard pile = 10 (done first, so the discarded tile cannot be re-drawn)
    state[discard_tile(state, action)] = 10
    # player hand = 1
    state[draw_tile(state)] = 1
    return state

def is_winning_hand(state):
    """
    Check if a hand is a terminal state (valid winning hand).
    A hand is winning when getShangTingDistance reports 0 for it.
    Returns True or False; the state passed in is not modified.
    """
    return bool(getShangTingDistance(state.copy()) == 0)

def state_to_string(state):
    """
    Converts a 2D array to string, for state dictionary comparison purposes (not Mahjong dict)

    Cells are written in flattened (row-major) order, one character each:
    '0' = draw pile, '1' = player's hand, '2' = discard pile (stored as 10 in
    the array; remapped so every cell stays a single character).

    The input array is NOT modified. The remapping is done on a copy, because
    rewriting the caller's 10s to 2s would erase the discard-pile marker that
    the rest of the notebook relies on.
    """
    # np.where builds a new array, leaving the caller's state untouched.
    encoded = np.where(np.asarray(state) == 10, 2, state)
    return ''.join(str(int(cell)) for cell in encoded.flatten())

def number_of_discards(state):
    """
    Counts the tiles in the discard pile, i.e. the cells of the state equal to 10.
    """
    in_discard_pile = (state == 10)
    return int(np.sum(in_discard_pile))

### Functions to Display Tiles

In [3]:
def _tile_figures_html(filenames, captions=None):
    """
    Builds the HTML for a row of inline tile images.
    If captions is given, captions[i] is shown under the i-th image.
    Helper function for display_picture
    """
    figures = []
    for position, image_path in enumerate(filenames):
        figure = f'<figure style="display:inline-block;"><img src="{image_path}" width="50"/>'
        if captions is not None:
            figure += f'<figcaption>{captions[position]}</figcaption>'
        figures.append(figure + '</figure>')
    return ''.join(figures)

def display_picture(deck, mahjong_dict=mahjong_dict, all_Rsa=None, display_discarded=False):
    """
    Displays a game state.

    deck: 2D game state (0 = draw pile, 1 = player's hand, 10 = discard pile).
          It is copied before use and never modified.
    mahjong_dict: tile index -> tile name, used to locate the tile pictures.
    all_Rsa: if given, the R(s,a) for each action (tile) is shown (rounded to
          3 decimals) under the corresponding tile, followed by the
          recommended discard. The recommendation comes from a fresh call to
          choose_from_Rsa, which breaks ties at random, so when several tiles
          share the best rating it may differ from a choice made separately
          by the caller.
    display_discarded: if truthy, the discard pile is also displayed (only
          when it is non-empty). Defaults to False; the former default was an
          empty dict, which behaved the same way but was a mutable default
          standing in for a boolean.
    """
    # Hide the discard pile so that only the player's hand remains.
    hand = deck.copy()
    hand[hand == 10] = 0
    # Obtain a list of image filenames
    filenames = display_picture_helper(hand, mahjong_dict)

    if all_Rsa is None:
        print('=================================== Player Hand  ===================================')
        display(HTML(_tile_figures_html(filenames)))
    else:
        print('========================== Player Hand and Discard Ratings ==========================')
        all_Rsa_rounded = [round(Rsa,3) for Rsa in all_Rsa]
        display(HTML(_tile_figures_html(filenames, all_Rsa_rounded)))

        # Display recommended tile to discard
        print('=== Recommended Discard Tile ===')
        tile_to_discard = choose_from_Rsa(all_Rsa)
        display(HTML(_tile_figures_html([filenames[tile_to_discard]],
                                        [all_Rsa_rounded[tile_to_discard]])))

    # Print discarded pile
    if display_discarded:
        # Re-mark discarded cells as 1 so the filename helper counts them.
        discarded = np.where(deck == 10, 1, 0)
        if discarded.any():
            print('\n')
            print('================================== Discard Pile ==================================')
            display(HTML(_tile_figures_html(display_picture_helper(discarded, mahjong_dict))))
    
def display_picture_helper(hand, mahjong_dict):
    """
    Given a 2D game_state whose held tiles are marked with 1 (everything else 0),
    returns a list of picture file name strings, one entry per held copy,
    ordered by tile index.
    Helper function for display_picture
    """
    picture_paths = []
    for tile_index, copies_held in enumerate(np.sum(hand, axis=1)):
        # Picture files are named after the lower-cased tile name.
        picture = 'tilePictures/' + mahjong_dict[tile_index].lower() + '.jpg'
        picture_paths.extend([picture] * int(copies_held))
    return picture_paths

### Compute the state-action reward, i.e., R(s,a)

In [4]:
def get_Rsa(original_state,action):
    """
    Computes the state-action reward R(s,a) for discarding one tile.

    The tile selected by action is moved to the discard pile on a copy of the
    state. Then every tile still in the draw pile is tried in turn as the
    replacement, and the ShangTing score of each resulting hand is averaged.
    The result is that average minus the ShangTing score of the current hand,
    so larger values mean a better expected hand after the discard.

    original_state is not modified. Raises ZeroDivisionError if the draw pile
    is empty after the discard.
    """
    baseline_score = getShangTingDistance(original_state.copy())

    after_discard = original_state.copy()
    after_discard[discard_tile(after_discard, action)] = 10

    # Flat positions of every tile that could still be drawn.
    drawable_positions = np.flatnonzero(after_discard.flatten() == 0).tolist()

    next_scores = []
    for flat_position in drawable_positions:
        candidate = after_discard.copy()
        candidate[divmod(flat_position, 4)] = 1
        next_scores.append(getShangTingDistance(candidate))
    return sum(next_scores)/len(drawable_positions) - baseline_score

def get_all_Rsa(original_state):
    """
    Computes R(s,a) for every possible discard in the current hand.

    Returns a list with one entry per tile in the player's hand, ordered like
    the actions accepted by get_Rsa / discard_tile. The number of actions is
    taken from the number of cells marked 1 rather than assumed to be 14, so
    the result is unchanged for a 14-tile hand and other hand sizes no longer
    index past the end of the hand.

    original_state is not modified.
    """
    state = original_state.copy()
    n_actions = int(np.count_nonzero(state == 1))
    # Each evaluation receives its own copy so the evaluations stay independent.
    return [get_Rsa(state.copy(), action) for action in range(n_actions)]

def choose_from_Rsa(all_Rsa):
    """
    Selects the argmax of Rsa (action that maximizes utility in the next state)
    If several actions share the maximum Rsa value, one of them is picked
    at random (via np.random.choice)
    """
    ratings = np.asarray(all_Rsa)
    best_actions = np.flatnonzero(ratings == ratings.max())
    return np.random.choice(best_actions)

### ShangTing Distance Function

In [5]:
def ignoreTile(tile_array):
    """
    Ignores a tile. Helper function for getShangTingDistance.
    Given a 1x4 array, the first entry equal to 1 is set to 0; the array is
    changed in place and returned. If no entry equals 1 it is returned as is.
    Eg: ignoreTile([1,0,1,1]) gives [0,0,1,1]
    """
    first_held = next(
        (position for position, value in enumerate(tile_array) if value == 1),
        None,
    )
    if first_held is not None:
        tile_array[first_held] = 0
    return tile_array

def getShangTingDistance(game_state):
    """
    Given a 2D game_state array, return a score based on the ShangTing
    distance defined in
    https://ieeexplore.ieee.org/abstract/document/10033435.

    The score measures how far the hand is from a basic winning hand and is
    computed as

        -14 + 3 * (identical melds + consecutive melds) + 2 if a pair remains

    so it is never positive: an empty hand scores -14 and a score of 0
    denotes a winning hand (four melds plus a pair). Each tile row holding
    three or more copies counts as one identical meld, each run of three
    consecutive tiles of the same suit counts as one consecutive meld, and a
    single pair among the tiles left over adds 2 (further pairs add nothing).

    Melds are found greedily: identical melds first, then one left-to-right
    scan per suit, so the count is not guaranteed to be the best possible
    grouping of the hand.

    Tiles in the discard pile (value 10) are ignored. The array passed in is
    not modified, because np.where below produces a new array.
    """
    # switch values where positions are 10 (in discard pile) to 0 (just for the purposes of computing hand score)
    game_state = np.where(game_state==10, 0, game_state)
    
    # Check for identical tiles =====================
    row_sums = np.sum(game_state, axis=1)
    triplet_mask = row_sums > 2  # boolean mask of the rows holding 3 or 4 copies of a tile
    num_iden_melds = triplet_mask.sum()  
    # Then remove 3 tiles from each of those rows, since they can't be counted again for consecutive melds
    # (a 4th copy, if present, stays available)
    triplet_indices = [i for i, x in enumerate(triplet_mask) if x]
    for idx in triplet_indices:
        for i in range(3):
            game_state[idx] = ignoreTile(game_state[idx])

    # Check for consecutive tiles =======================
    # Segment the game state into tiles of the same suit. tongs: indx 0 to 8, wans: indx 9 to 17 inclusive,
    # suos: indx 18 to 26 inclusive. Rows from 27 onwards are not scanned for consecutive melds.
    # The slices are views, so tiles removed below are also removed from game_state.
    tongs, wans, suos = game_state[:9], game_state[9:18], game_state[18:27]
    total_triplets_count = 0
    for suite in tongs, wans, suos:
        # streak = length of the current run of consecutive occupied rows (0, 1 or 2)
        # previous = 1 if the row just before this one was occupied
        streak = 0
        previous = 0
        triple_consecutive_indices = []  # never read afterwards
        triple_consecutive_count = 0
        for i in range(len(suite)):
            row = suite[i]
            # row is going to look like [1, 0, 0, 0]
            if sum(row) > 0:
                # 2 in a row
                if previous == 1 and streak == 1:
                    streak = 2
                    # leave previous as 1
                elif streak == 2:
                    # 3 in a row. update and remove tiles from further consideration
                    streak = 0
                    previous = 0
                    triple_consecutive_count += 1
                    # ignore one copy of this tile and of the 2 previous ones
                    for j in range(3):
                        indx_to_ignore = i - j
                        suite[indx_to_ignore] = ignoreTile(suite[indx_to_ignore])
                else: 
                    # first occupied row of a new run
                    previous = 1
                    streak = 1
            else:
                # an empty row breaks the run
                previous = 0
                streak = 0
        total_triplets_count += triple_consecutive_count

    # Check for a pair from the remaining tiles. game_state has been updated. =====
    remaining_row_sums = np.sum(game_state, axis=1)
    pair_mask = remaining_row_sums > 1
    has_pair_score = 0
    if any(pair_mask):
        has_pair_score = 2

    # Start from -14 and add 3 per meld plus 2 for the pair; 0 means a complete hand.
    shangtingDistance = - 14 + total_triplets_count * 3 + num_iden_melds * 3 + has_pair_score
    # print(f"# identical triplets: {num_iden_melds}  |  # consecutive threes: {total_triplets_count}  |  has pair: {bool(has_pair_score)}")
    return shangtingDistance

### Solve the MDP

In [162]:
# Initialize a new random Mahjong game
game_state = initialize_mahjong()

# Greedy policy: keep discarding the best-rated tile and drawing a random
# replacement until the hand is a winning hand.
while not is_winning_hand(game_state.copy()):  
    # Compute R(s,a) for all actions
    all_Rsa = get_all_Rsa(game_state.copy())

    # Identify the optimal tile to discard (ties are broken at random)
    tile_to_discard = choose_from_Rsa(all_Rsa)

    # Display player's hand & the corresponding R(s,a), also displays discarded tile if requested
    # NOTE: display_picture calls choose_from_Rsa again for its "recommended" tile, so when
    # several tiles tie for the best rating the tile shown may differ from tile_to_discard.
    display_picture(game_state.copy(),all_Rsa=all_Rsa,display_discarded=True)

    # Update the Mahjong game state (discard + random draw)
    game_state = update_state(game_state,tile_to_discard)
    
    # Schedule the current output to be cleared once the next output arrives,
    # then pause; adjust how long to sleep after each discard so that you can watch
    clear_output(wait=True)
    time.sleep(1)

# Report the final (winning) hand and how many discards it took
print("Winning Mahjong hand obtained!")
print("Total number of discards = "+str(number_of_discards(game_state.copy())))
display_picture(game_state.copy(),display_discarded=True)    

Winning Mahjong hand obtained!
Total number of discards = 38






### Run 1000 Mahjong games to collect statistics

In [17]:
# Per-game statistics: number of discards needed to win, and solve time in seconds
n_discards = []
n_runtime = []

for n_games in range(1000):
    game_state = initialize_mahjong()
    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # intervals; time.time() can jump if the system clock is adjusted.
    start_time = time.perf_counter()
    while not is_winning_hand(game_state.copy()):
        all_Rsa = get_all_Rsa(game_state.copy())
        tile_to_discard = choose_from_Rsa(all_Rsa)
        game_state = update_state(game_state,tile_to_discard)
    # Stop the clock before touching the notebook output, so the recorded
    # runtime covers only the solve loop and not the progress display.
    n_runtime.append(time.perf_counter()-start_time)
    n_discards.append(number_of_discards(game_state.copy()))
    # Progress indicator: replace the previous counter with the current game index
    clear_output(wait=True)
    print(n_games)

999
