# 1.验证模型结构

In [1]:
import torch
import numpy as np
# 检测GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

# 测试用输入
batch =  []
batch_size = 32
feature_channels = 34
feature_height = 4
feature_width = 34

action_space = 49

seq_lens = []
for i in range(batch_size):
    seq_lens.append(np.random.randint(1, 129))

for seq_len in seq_lens:
    
    action_list = [i for i in range(seq_len)]
    self_action_mask = np.random.randint(0, 2, (seq_len))
    num_self_actions = np.sum(self_action_mask)
    legal_action_mask = np.random.randint(0, 2, (num_self_actions, action_space))

    info = {
        'scores': np.random.rand(num_self_actions, 4),
        'oya': np.random.randint(0, 4, (num_self_actions,)),
        'honba_riichi_sticks': np.random.randint(0, 8, (num_self_actions, 2)),
    }

    tile_features = np.random.randint(0, 2, (num_self_actions, feature_channels, feature_height, feature_width))
    batch.append({
        'action_list': action_list,
        'info': info,
        'tile_features': tile_features,
        'reward': 1000,
        'self_action_mask': self_action_mask,
        'legal_action_mask': legal_action_mask
    })

for i, sample in enumerate(batch):
    action_list = sample['action_list']
    self_action_mask = sample['self_action_mask']
    legal_action_mask = sample['legal_action_mask']
    info = sample['info']
    tile_features = sample['tile_features']
    print("action_list len:", len(action_list))
    print("scores shape:", info['scores'].shape)
    print("self_action_mask shape:", self_action_mask.shape)
    print("legal_action_mask shape:", legal_action_mask.shape)
    print("oya shape:", info['oya'].shape)
    print("honba_riichi_sticks shape:", info['honba_riichi_sticks'].shape)
    print("tile_features shape:", tile_features.shape)
    print("-----------------------------------------------------")
    if i == 3:
        break

device: cuda
action_list len: 45
scores shape: (22, 4)
self_action_mask shape: (45,)
legal_action_mask shape: (22, 49)
oya shape: (22,)
honba_riichi_sticks shape: (22, 2)
tile_features shape: (22, 34, 4, 34)
-----------------------------------------------------
action_list len: 5
scores shape: (2, 4)
self_action_mask shape: (5,)
legal_action_mask shape: (2, 49)
oya shape: (2,)
honba_riichi_sticks shape: (2, 2)
tile_features shape: (2, 34, 4, 34)
-----------------------------------------------------
action_list len: 14
scores shape: (4, 4)
self_action_mask shape: (14,)
legal_action_mask shape: (4, 49)
oya shape: (4,)
honba_riichi_sticks shape: (4, 2)
tile_features shape: (4, 34, 4, 34)
-----------------------------------------------------
action_list len: 87
scores shape: (41, 4)
self_action_mask shape: (87,)
legal_action_mask shape: (41, 49)
oya shape: (41,)
honba_riichi_sticks shape: (41, 2)
tile_features shape: (41, 34, 4, 34)
-----------------------------------------------------


In [2]:
from base_module import myCollator
# 验证collator
collator = myCollator()
data = collator(batch)

In [3]:
print("action_list shape:", data['action_list'].shape)
print("mask shape", data['mask'].shape)
print(data['mask'][:1])
print("self_action_mask shape:", data['self_action_mask'].shape)
# print("self_action_mask:", data['self_action_mask'][:1])
num_self_actions = data['self_action_mask'].sum().item()
print("num_self_actions", num_self_actions)
print("legal_action_mask shape:", data['legal_action_mask'].shape)
print("---------------------------------------------")
print("scores shape:", data['info']['scores'].shape)
print("oya shape:", data['info']['oya'].shape)
print("honba_riichi_sticks shape:", data['info']['honba_riichi_sticks'].shape)
print("---------------------------------------------")
print("tile_features shape:", data['tile_features'].shape)
print("Q shape:", data['Q'].shape)
print("Q:", data['Q'][:50])


action_list shape: torch.Size([32, 128])
mask shape torch.Size([32, 128])
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.]], device='cuda:0')
self_action_mask shape: torch.Size([32, 128])
num_self_actions 1204
legal_action_mask shape: torch.Size([1204, 49])
---------------------------------------------
scores shape: torch.Size([1204, 4])
oya shape: torch.Size([1204])
honba_riichi_sticks shape: torch.Size([1204, 2])
---------------------------------------------


In [4]:
from base_module import *
import torch
import torch.nn as nn
import numpy as np
from transformers import GPT2Model, GPT2Config



# 创建模型
config = GPT2Config(n_embd=512, n_layer=8, n_head=8, n_positions=128)
model = Policy_Network(config).to(device)


In [5]:
model.eval()
output = model.forward(**data)
print("action_logits shape:", output['action_logits'].shape)
# print(output['action_logits'][0])
print("action_probs shape:", output['action_probs'].shape)
# print(output['action_probs'][0])
print("loss :", output['loss'])

action_logits shape: torch.Size([1204, 49])
action_probs shape: torch.Size([1204, 49])
loss : tensor(2976726.5000, device='cuda:0', grad_fn=<NegBackward0>)


# 2.验证麻将环境

In [3]:
from pymahjong import *
from pymahjong import MahjongPyWrapper as pm

table = pm.Table()
table.game_init()


#### 执行动作

In [51]:
who = table.who_make_selection()
print(f"player{who} is making selection")
phase = table.get_phase()
if phase < 4:

    self_action = table.get_self_actions()
    print("avaliable self actions:")
    for action in self_action:
        # print("action:", action.action, "tile:", action.correspond_tiles[0].id)
        print(action.to_string())
elif phase < 16:
    actions = table.get_response_actions()
    print("avaliable response actions:")
    for action in actions:
        print(action.to_string())

elif phase == 16:
    print("game over")

else:
    print("error")


player2 is making selection
avaliable self actions:
Discard 2m
Discard 3m
Discard 6m
Discard 1p
Discard 4p
Discard 7p
Discard 7p
Discard 9p
Discard 1s
Discard 2s
Discard 7s
Discard 2z
Discard 5z
Discard 7z


In [50]:
table.make_selection(0)

#### 查看phase和状态

In [46]:
table.players[1].get_river().to_string()

'6s2h '

In [39]:
phase = table.get_phase()
last_action = table.last_action
honba = table.honba
oya = table.oya
riichi_sticks = table.riichibo
dora = table.get_dora()
yama = table.yama

print("phase:", phase)
print("last_action:", last_action)
print("honba:", honba)
print("oya:", oya)
print("riichi_sticks:", riichi_sticks)
print("dora:", dora[0])
print("yama left:", table.get_remain_tile())
for tile in yama:
    print("id:", tile.id, "tile:", tile.tile, "is_red:", tile.red_dora)




phase: 4
last_action: BaseAction.Discard
honba: 0
oya: 0
riichi_sticks: 0
dora: BaseTile._6s
yama left: 68
id: 25 tile: BaseTile._7m is_red: False
id: 94 tile: BaseTile._6s is_red: False
id: 56 tile: BaseTile._6p is_red: False
id: 33 tile: BaseTile._9m is_red: False
id: 17 tile: BaseTile._5m is_red: False
id: 91 tile: BaseTile._5s is_red: False
id: 78 tile: BaseTile._2s is_red: False
id: 47 tile: BaseTile._3p is_red: False
id: 103 tile: BaseTile._8s is_red: False
id: 120 tile: BaseTile.north is_red: False
id: 27 tile: BaseTile._7m is_red: False
id: 2 tile: BaseTile._1m is_red: False
id: 111 tile: BaseTile.east is_red: False
id: 93 tile: BaseTile._6s is_red: False
id: 68 tile: BaseTile._9p is_red: False
id: 28 tile: BaseTile._8m is_red: False
id: 55 tile: BaseTile._5p is_red: False
id: 77 tile: BaseTile._2s is_red: False
id: 60 tile: BaseTile._7p is_red: False
id: 112 tile: BaseTile.south is_red: False
id: 109 tile: BaseTile.east is_red: False
id: 107 tile: BaseTile._9s is_red: False
id