# 1. 验证模型结构

In [None]:
import torch
import numpy as np
from base_module import *

from pymahjong import *
from pymahjong import MahjongPyWrapper as pm
from pymahjong.myEnv_pymahjong import myMahjongEnv

env = myMahjongEnv()
env.reset()

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", device)

# Dimensions for the smoke-test input. Only batch_size is consumed in this cell;
# the others are kept as reference constants.
batch_size = 16
feature_channels, feature_height, feature_width = 35, 4, 9
action_space = 47

# Play complete games choosing uniformly among the valid actions. Every finished game
# contributes one observation-with-return per seat (4 per game) until the batch is full.
batch = []
while len(batch) < batch_size:
    while not env.is_over():
        curr_pid = env.get_curr_player_id()
        action = np.random.choice(env.get_valid_actions())
        env.step(player_id=curr_pid, action=action)
    print(env.get_payoffs())
    batch.extend(env.get_observation_with_return(seat) for seat in range(4))
    env.reset()

collator = myCollator()

In [None]:
from base_module import *
import torch
import torch.nn as nn
import numpy as np
from transformers import GPT2Model, GPT2Config


# Use the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

# Build the policy network (GPT-2 style config) and switch it to inference mode.
config = GPT2Config(n_embd=512, n_layer=8, n_head=8, n_positions=128)
model = Policy_Network(config).to(device)
model.eval()

# Collate the smoke-test batch, then move every tensor onto the model's device.
# `collator` was constructed without a device argument, so its output is not guaranteed
# to live on `device`; feeding CPU tensors to a CUDA model would fail in the forward pass.
# The "info" entry is a nested dict of tensors and is moved entry by entry.
data = collator(batch)
for key in list(data.keys()):
    if key == "info":
        for sub_key in data[key].keys():
            data[key][sub_key] = data[key][sub_key].to(device)
    else:
        data[key] = data[key].to(device)


In [None]:

# Forward pass for inspection only: no gradients are needed, so skip building the
# autograd graph (saves memory and keeps printed tensors free of grad_fn noise).
with torch.no_grad():
    output = model.forward(**data)
print("action_logits:", output["action_logits"].shape)
print(output["action_logits"][0])
print("probs:", output["action_probs"][0])
# print("loss:", output["loss"])
print("action:", output["action"])


# 2. 验证麻将环境

In [1]:
from pymahjong import *
from pymahjong import MahjongPyWrapper as pm
from pymahjong.myEnv_pymahjong import myMahjongEnv

# Fresh environment for the interactive checks in this section; the value returned by
# reset() is what this cell displays.
env = myMahjongEnv()
env.reset()

In [None]:
# Show the valid action ids next to the env's recorded legal-action mask so the two can
# be compared by eye. The call happens before the attribute read, as in a two-print version.
for shown in (env.get_valid_actions(), env.legal_actions_mask_record):
    print(shown)

In [None]:
# This section starts from a fresh kernel (In [1]) and its import cell only pulls in
# pymahjong, so numpy must be imported here rather than relying on leftover kernel state.
import numpy as np

# Play the current game to completion with uniformly random valid actions, then show the
# final payoffs.
while not env.is_over():
    curr_pid = env.get_curr_player_id()
    valid_actions = env.get_valid_actions()
    action = np.random.choice(valid_actions)
    env.step(player_id=curr_pid, action=action)
print(env.get_payoffs())

#### 展示当前phase和合法动作列表

In [None]:
# `_proceed()` is a private env method returning (phase, candidate actions); each
# candidate exposes `to_string()`.
# NOTE(review): being private, confirm it does not advance the game state when called
# just for display.
phase, available_actions = env._proceed()
print("phase:", phase)
for idx, candidate in enumerate(available_actions):
    print(idx, candidate.to_string())

# Integer action ids the env accepts for the current decision.
valid_action = env.get_valid_actions()
print("valid_action:", valid_action)

#### 执行动作

In [4]:
# Manually apply action id 0 for seat 2. The seat is hard-coded; presumably it must be the
# seat due to act (env.get_curr_player_id()) — adjust before re-running.
env.step(player_id=2, action=0)

In [None]:
# Dump the hand tiles of every seat (via the env's private accessor).
for seat in range(4):
    print("player", seat, "hand:", env._get_hand_tiles(seat))

In [None]:
# Inspect the observation for seat 1: shapes of the model inputs, then the raw contents
# of the action list and the masks.
obs = env.get_observation(1)
obs_info = obs["info"]

print("tiles_features shape:", obs["tiles_features"].shape)
print("oya shape:", obs_info["oya"].shape)
print("riichi_sticks shape:", obs_info["riichi_sticks"].shape)

print("action_list shape:", obs["action_list"].shape)
print("action_list:", obs["action_list"])

self_mask = obs["self_action_mask"]
print("self_action_mask shape:", self_mask.shape)
print("self_action_mask:", self_mask)
print("sum self_action_mask:", self_mask.sum())

print("attention_mask shape:", obs["attention_mask"].shape)
print("attention_mask:", obs["attention_mask"])

print("Q shape:", obs["Q_values"].shape)
print("Q:", obs["Q_values"])

print("legal_action_mask shape:", obs["legal_action_mask"].shape)
print("legal_action_mask:", obs["legal_action_mask"])

In [None]:
# Per seat: its fuuro list (presumably the called melds) and its current points.
for seat in range(4):
    print("player", seat, "fuuros:", env._get_fuuros(seat))
    print("points:", env._get_points(seat))

In [None]:
# Display the auxiliary info dict attached to the observation (contains e.g. 'oya' and
# 'riichi_sticks').
obs['info']

# 3. 使用模型来决策

In [1]:
import os

# Expose only physical GPU 1 to this process. The variable is only honoured if CUDA has
# not been initialised yet in this kernel, hence it is set in the first cell.
os.environ.update(CUDA_VISIBLE_DEVICES="1")
import torch

# With a single visible GPU, "cuda:0" maps to physical device 1.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


In [None]:
import torch
import numpy as np
from base_module import *

from pymahjong import *
from pymahjong import MahjongPyWrapper as pm
from pymahjong.myEnv_pymahjong import myMahjongEnv

env = myMahjongEnv()
env.reset()

# Randomly initialised policy network, used here purely for action selection.
# GPT2Config is expected to come from `from base_module import *` — TODO confirm.
config = GPT2Config(n_embd=512, n_layer=8, n_head=8, n_positions=128)
model = Policy_Network(config).to(device)
model.eval()

collator = myCollator()

# Self-play 100 games in which the same model drives every seat; print each game's payoffs.
for i in range(100):
    while not env.is_over():
        curr_pid = env.get_curr_player_id()
        obs = env.get_observation(curr_pid)
        # Named `model_input` (not `input`) so the builtin input() is not shadowed in the
        # kernel's global namespace. A leading batch dim of 1 is added to every field except
        # tiles_features and legal_action_mask, which are passed as the env returns them.
        model_input = {
            "tiles_features": torch.tensor(obs['tiles_features'], dtype=torch.float32, device=device),
            "oya": torch.tensor(obs['info']['oya'], dtype=torch.float32, device=device).unsqueeze(0),
            "riichi_sticks": torch.tensor(obs['info']['riichi_sticks'], dtype=torch.float32, device=device).unsqueeze(0),
            "action_list": torch.tensor(obs['action_list'], dtype=torch.long, device=device).unsqueeze(0),
            "attention_mask": torch.tensor(obs['attention_mask'], dtype=torch.long, device=device).unsqueeze(0),
            "legal_action_mask": torch.tensor(obs['legal_action_mask'], dtype=torch.bool, device=device),
        }
        # Action selection only: no autograd bookkeeping needed.
        with torch.no_grad():
            output = model.inference(**model_input)
        action = output["action"].item()
        env.step(player_id=curr_pid, action=action)
        # print("player:", curr_pid, "action:", action)
    print(env.get_payoffs())
    env.reset()

# print("action:", output["action"])

# 4. 开始训练

In [1]:
import os

# Expose physical GPUs 0 and 1. This only takes effect if CUDA has not been initialised
# yet in this kernel, hence it is done first.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
import torch

# Self-play inference runs on the first visible GPU (physical 0), gradient updates on the
# second (physical 1). If only one GPU is visible, both roles share cuda:0 instead of
# crashing on a non-existent cuda:1; without CUDA everything runs on the CPU.
if torch.cuda.is_available():
    inference_device = torch.device("cuda:0")
    training_device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    inference_device = training_device = torch.device("cpu")

print("inference_device:", inference_device)
print("training_device:", training_device)

import numpy as np
from base_module import *

from pymahjong import *
from pymahjong import MahjongPyWrapper as pm
from pymahjong.myEnv_pymahjong import myMahjongEnv

env = myMahjongEnv()
env.reset()

# Two copies of the same network: `inference_model` plays games, `training_model` is optimised.
config = GPT2Config(n_embd=512, n_layer=8, n_head=8, n_positions=128)
inference_model = Policy_Network(config).to(inference_device)
inference_model.eval()

training_model = Policy_Network(config).to(training_device)
# Start both copies from identical weights.
training_model.load_state_dict(inference_model.state_dict())
training_model.train()

collator = myCollator(device=training_device)
optimizer = torch.optim.Adam(training_model.parameters(), lr=1e-4)

inference_device: cuda:0
training_device: cuda:1


In [2]:
import os

epoch = 10            # number of training epochs (one checkpoint per epoch)
num_games = 10        # self-play games generated per epoch
batch_size = 4        # observations per optimizer step
checkpoint_dir = "./checkpoints"


def obs_to_model_input(obs, device):
    """Convert one env observation into keyword arguments for `Policy_Network.inference`.

    A leading batch dimension of 1 is added to every field except `tiles_features` and
    `legal_action_mask`, which are passed through as the env returns them
    (presumably already batched — TODO confirm).
    """
    return {
        "tiles_features": torch.tensor(obs['tiles_features'], dtype=torch.float32, device=device),
        "oya": torch.tensor(obs['info']['oya'], dtype=torch.float32, device=device).unsqueeze(0),
        "riichi_sticks": torch.tensor(obs['info']['riichi_sticks'], dtype=torch.float32, device=device).unsqueeze(0),
        "action_list": torch.tensor(obs['action_list'], dtype=torch.long, device=device).unsqueeze(0),
        "attention_mask": torch.tensor(obs['attention_mask'], dtype=torch.long, device=device).unsqueeze(0),
        "legal_action_mask": torch.tensor(obs['legal_action_mask'], dtype=torch.bool, device=device),
    }


# Collated batches are moved to training_device explicitly below, so a device-less
# collator is enough; build it once instead of once per epoch.
batch_collator = myCollator()
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch_idx in range(epoch):
    # 1) Self-play: generate num_games games with the current inference model.
    dataset = []
    for _ in range(num_games):
        while not env.is_over():
            curr_pid = env.get_curr_player_id()
            obs = env.get_observation(curr_pid)
            # Rollouts never backpropagate, so skip autograd bookkeeping.
            with torch.no_grad():
                output = inference_model.inference(**obs_to_model_input(obs, inference_device))
            action = output["action"].item()
            env.step(player_id=curr_pid, action=action)
        # One observation-with-return per seat of the finished game.
        for seat in range(4):
            dataset.append(env.get_observation_with_return(seat))
        print(env.get_payoffs())
        env.reset()

    # 2) Policy update: one optimizer step per consecutive slice of batch_size observations.
    for start in range(0, len(dataset), batch_size):
        optimizer.zero_grad()
        data = batch_collator(dataset[start:start + batch_size])
        # Move everything to training_device; "info" is a nested dict of tensors.
        for key in list(data.keys()):
            if key == "info":
                for sub_key in data[key].keys():
                    data[key][sub_key] = data[key][sub_key].to(training_device)
            else:
                data[key] = data[key].to(training_device)
        # Call the module (not .forward) so registered hooks run.
        output = training_model(**data)
        loss = output["loss"]
        loss.backward()
        optimizer.step()
        print("loss:", loss.item())

    # 3) Sync the rollout model with the updated weights and save a checkpoint for THIS
    # epoch. The old path used the constant `epoch`, so every epoch overwrote epoch_10.pth;
    # 1-based numbering keeps the final checkpoint at that same name.
    inference_model.load_state_dict(training_model.state_dict())
    torch.save(inference_model.state_dict(), os.path.join(checkpoint_dir, f"epoch_{epoch_idx + 1}.pth"))






  "legal_action_mask": torch.tensor(obs['legal_action_mask'], dtype=bool).to(inference_device)


[    0. -2000.     0.  2000.]
[0. 0. 0. 0.]
[0. 0. 0. 0.]
[0. 0. 0. 0.]
[0. 0. 0. 0.]
[0. 0. 0. 0.]
[-1000. -1000.  3000. -1000.]
[-1000. -1000. -1000.  3000.]
[0. 0. 0. 0.]
[0. 0. 0. 0.]
loss: -4830.759765625
loss: -0.0
loss: -0.0
loss: -0.0
loss: -0.0
loss: -0.0
loss: -17482.611328125
loss: -2279.8203125
loss: -0.0
loss: -0.0
[ 1500. -1500. -1500.  1500.]
[ 2700.  -700. -1300.  -700.]
-------------- execption in make_selection_from_action_basetile ------------------
Cannot locate action with action = 0
BaseAction.Pass []
[[1 0 0 0 0 0 0 1 1 0 0 1 0 0 0 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 1 1 0]
 [1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 1 1 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 1 0 0 0 0 1 0 0 1 1 1 1 0 1 0 1 1 0 0 0 1 1 1 1 0 0 0 0 0 1]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0

SystemError: 

In [11]:
# Debug snapshot of the state at the failed env.step:
# - current phase and acting seat;
# - the valid action ids and the action the model chose;
# - every non-zero slot of the env's action container.
# In the recorded output the container has a non-zero slot (41) that
# get_valid_actions() does not list.
print("phase:", env.t.get_phase())
print("player:", env.get_curr_player_id())
print("legal actions:", env.get_valid_actions())
print("action", action)
action_container = env.act_container
nonzero_index = []
for pos in range(len(action_container)):
    if action_container[pos] != 0:
        nonzero_index.append(pos)
print("nonzero_index:", nonzero_index)

phase: 3
player: 3
legal actions: [ 0  7  8 11 18 22 31 32 45]
action 45
nonzero_index: [0, 7, 8, 11, 18, 22, 31, 32, 41, 45]


In [13]:
# Re-apply the step that failed (seat 3, action 45) to reproduce the error in isolation.
env.step(player_id=3, action=45)

-------------- execption in make_selection_from_action_basetile ------------------
Cannot locate action with action = 0
BaseAction.Pass []
[[1 0 0 0 0 0 0 1 1 0 0 1 0 0 0 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 1 1 0]
 [1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 1 1 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 1 0 0 0 0 1 0 0 1 1 1 1 0 1 0 1 1 0 0 0 1 1 1 1 0 0 0 0 0 1]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 1 1

SystemError: 