# 1.验证模型结构

In [None]:
# Smoke-test inputs for the model: play random games and collect one
# return-annotated observation per seat per finished game.
import torch
import numpy as np
from base_module import *

from pymahjong import *
from pymahjong import MahjongPyWrapper as pm
from pymahjong.myEnv_pymahjong import myMahjongEnv

env = myMahjongEnv()
env.reset()


# Check whether a GPU is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

# Test inputs.
batch = []
batch_size = 16
# The feature_* sizes and action_space are not used by this cell; they record
# the expected input layout (35 channels matches the Conv2d(35, ...) printed
# for Tiles_CNN; 47 presumably = number of action indices -- confirm).
feature_channels = 35
feature_height = 4
feature_width = 9

action_space = 47

# Every finished game adds 4 samples (one per player), so keep playing until
# at least batch_size samples exist.
while len(batch) < batch_size:
    while not env.is_over():
        curr_pid = env.get_curr_player_id()
        valid_actions = env.get_valid_actions()
        action = np.random.choice(valid_actions)
        # `action_idx` is the keyword every other (executed) step() call in
        # this notebook uses.
        env.step(player_id=curr_pid, action_idx=action)
    print(env.get_payoffs())
    for i in range(4):
        batch.append(env.get_observation_with_return(i))
    env.reset()
collator = myCollator()

In [None]:
from base_module import *
import torch
import torch.nn as nn
import numpy as np
from transformers import GPT2Model, GPT2Config


# Check whether a GPU is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)
# Build the model.
config = GPT2Config(n_embd=512, n_layer=8, n_head=8, n_positions=128)
model = Policy_Network(config).to(device)
data = collator(batch)
# `collator` was constructed without a device, so move every tensor
# (including those in the nested "info" dict) onto the model's device.
# Assumes the same key layout the training loop relies on.
for key in data.keys():
    if key == "info":
        for sub_key in data[key].keys():
            data[key][sub_key] = data[key][sub_key].to(device)
    else:
        data[key] = data[key].to(device)
model.eval()


In [None]:

# Inspection-only forward pass. no_grad skips building an autograd graph,
# and calling the module (rather than .forward) also runs any registered hooks.
with torch.no_grad():
    output = model(**data)
print("action_logits:", output["action_logits"].shape)
print(output["action_logits"][0])
print("probs:", output["action_probs"][0])
print("action:", output["action"])


# 2.验证麻将环境

In [1]:
# Create and reset a fresh mahjong environment for the interactive checks
# in this section.
from pymahjong import *
from pymahjong import MahjongPyWrapper as pm
from pymahjong.myEnv_pymahjong import myMahjongEnv

env = myMahjongEnv()
env.reset()

In [None]:
# Print the legal action indices for the player to act, then the env's
# `legal_actions_mask_record` attribute (presumably bookkeeping that
# get_valid_actions updates -- confirm; keep this call-then-read order).
print(env.get_valid_actions())
print(env.legal_actions_mask_record)

In [None]:
# This section runs in a restarted kernel whose setup cell only imports
# pymahjong, so numpy must be imported here for np.random.choice.
import numpy as np

# Finish the current game with uniformly random legal actions, then print
# the final payoffs.
while not env.is_over():
    curr_pid = env.get_curr_player_id()
    valid_actions = env.get_valid_actions()
    action = np.random.choice(valid_actions)
    # `action_idx` is the keyword the executed step() calls in this notebook use.
    env.step(player_id=curr_pid, action_idx=action)
print(env.get_payoffs())

#### 展示当前phase和合法动作列表

In [None]:
# Call the env's private _proceed(), which yields the current phase and the
# raw action objects (it may also advance internal state -- confirm), list
# those actions, then show the env's legal action indices.
phase, aviable_action = env._proceed()
print("phase:", phase)
for position, candidate in enumerate(aviable_action):
    label = candidate.to_string()
    print(position, label)

valid_action = env.get_valid_actions()
print("valid_action:", valid_action)

#### 执行动作

In [4]:
# Manually step player 2 with action index 0.
# NOTE(review): player id is hard-coded; this only makes sense when player 2
# is the one to act (other cells use env.get_curr_player_id()).
env.step(2, 0)

In [None]:
# Print each of the four seats' hand tiles.
for seat in range(4):
    hand = env._get_hand_tiles(seat)
    print("player", seat, "hand:", hand)

In [None]:
# Dump shapes and contents of player 1's observation dict.
obs = env.get_observation(1)
print("tiles_features shape:", obs['tiles_features'].shape)
info = obs['info']
print("oya shape:", info['oya'].shape)
print("riichi_sticks shape:", info['riichi_sticks'].shape)
# (shape label, value label, observation key), printed in this order.
for shape_label, value_label, key in (
    ("action_list shape:", "action_list:", "action_list"),
    ("self_action_mask shape:", "self_action_mask:", "self_action_mask"),
    ("attention_mask shape:", "attention_mask:", "attention_mask"),
    ("Q shape:", "Q:", "Q_values"),
    ("legal_action_mask shape:", "legal_action_mask:", "legal_action_mask"),
):
    value = obs[key]
    print(shape_label, value.shape)
    print(value_label, value)
    if key == "self_action_mask":
        print("sum self_action_mask:", value.sum())

In [None]:
# For each seat, print what env._get_fuuros returns, followed by its points.
for seat in range(4):
    fuuros = env._get_fuuros(seat)
    print("player", seat, "fuuros:", fuuros)
    points = env._get_points(seat)
    print("points:", points)

In [None]:
# Display the 'info' sub-dict of the observation fetched by the dump cell.
obs['info']

# 3.使用模型来决策

In [1]:
import os

# Expose only physical GPU 1 to this process. This must be set before torch
# initialises CUDA, which is why it precedes the torch import.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch

# With one visible GPU, "cuda:0" refers to physical GPU 1.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


In [None]:
import torch
import numpy as np
from base_module import *

from pymahjong import *
from pymahjong import MahjongPyWrapper as pm
from pymahjong.myEnv_pymahjong import myMahjongEnv

# Let a freshly initialised Policy_Network play 100 games in every seat and
# print each game's payoffs.
env = myMahjongEnv()
env.reset()

# GPT2Config is not imported in this cell; it presumably comes from
# `from base_module import *`.
config = GPT2Config(n_embd=512, n_layer=8, n_head=8, n_positions=128)
model = Policy_Network(config).to(device)
model.eval()

collator = myCollator()
for i in range(100):
    while not env.is_over():
        curr_pid = env.get_curr_player_id()
        obs = env.get_observation(curr_pid)
        # Named `model_inputs` (not `input`) so the builtin is not shadowed in
        # the notebook namespace.
        model_inputs = {
            "tiles_features": torch.tensor(obs['tiles_features'], dtype=torch.float32).to(device),
            "oya": torch.tensor(obs['info']['oya'], dtype=torch.float32).unsqueeze(0).to(device),
            "riichi_sticks": torch.tensor(obs['info']['riichi_sticks'], dtype=torch.float32).unsqueeze(0).to(device),
            "action_list": torch.tensor(obs['action_list'], dtype=torch.long).unsqueeze(0).to(device),
            "attention_mask": torch.tensor(obs['attention_mask'], dtype=torch.long).unsqueeze(0).to(device),
            "legal_action_mask": torch.tensor(obs['legal_action_mask'], dtype=torch.bool).to(device),
        }
        # Pure inference: skip autograd bookkeeping, as the training cell does.
        with torch.no_grad():
            output = model.inference(**model_inputs)
        action = output["action"].item()
        env.step(player_id=curr_pid, action_idx=action)
    print(env.get_payoffs())
    env.reset()

# 4. 开始训练

In [1]:
import os

# Expose only physical GPUs 2 and 3. This must be set before torch
# initialises CUDA, hence before the torch import.
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
import torch

# Visible GPUs are renumbered from 0: cuda:0 is physical GPU 2 (self-play
# inference) and cuda:1 is physical GPU 3 (training). If only one GPU is
# actually visible, training shares cuda:0 instead of failing on cuda:1.
if torch.cuda.is_available():
    inference_device = torch.device("cuda:0")
    training_device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    inference_device = torch.device("cpu")
    training_device = torch.device("cpu")

print("inference_device:", inference_device)
print("training_device:", training_device)

import numpy as np
from base_module import *

from pymahjong import *
from pymahjong import MahjongPyWrapper as pm
from pymahjong.myEnv_pymahjong import myMahjongEnv

env = myMahjongEnv()
env.reset()

# GPT2Config presumably comes from `from base_module import *`.
config = GPT2Config(n_embd=512, n_layer=8, n_head=8, n_positions=128)
# Two copies of the policy: one kept in eval mode for self-play, one trained.
inference_model = Policy_Network(config).to(inference_device)
inference_model.eval()

training_model = Policy_Network(config).to(training_device)
# Start both copies from identical weights.
training_model.load_state_dict(inference_model.state_dict())
training_model.train()

collator = myCollator(device=training_device)
inference_collator = inference_Collator(device=inference_device)
optimizer = torch.optim.Adam(training_model.parameters(), lr=1e-4)

inference_device: cuda:0
training_device: cuda:1


In [2]:
import os
import random

from tqdm import tqdm

# Self-play/training loop. Each epoch collects games with inference_model,
# trains training_model on them, then copies the weights back and saves a
# checkpoint.
epochs = 1
num_games = 4
train_epochs = 1
# Samples per optimizer step. It is constant, so it is defined once rather
# than on every epoch.
batch_size = 4
checkpoint_dir = "./checkpoints"
# torch.save does not create missing directories.
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in tqdm(range(epochs)):
    # Generate num_games games of data (4 samples per game, one per seat).
    dataset = []
    while len(dataset) < num_games * 4:
        while not env.is_over():
            curr_pid = env.get_curr_player_id()
            obs = env.get_observation(curr_pid)
            # `model_inputs`, not `input`, to avoid shadowing the builtin.
            model_inputs = inference_collator(obs)
            with torch.no_grad():
                output = inference_model.inference(**model_inputs)
                action = output["action"].item()
            env.step(player_id=curr_pid, action_idx=action)
        payoffs = env.get_payoffs()
        # Keep only games where at least one payoff is non-zero; all-zero games
        # are discarded. Assumes get_payoffs returns a NumPy array, so `==` is
        # element-wise.
        if not (payoffs == [0, 0, 0, 0]).all():
            for player_id in range(4):
                dataset.append(env.get_observation_with_return(player_id))
            print(payoffs)
        env.reset()

    for train_epoch in tqdm(range(train_epochs)):
        # Shuffle, then walk the dataset in mini-batches of batch_size.
        random.shuffle(dataset)
        for batch_idx in range(0, len(dataset), batch_size):
            batch = dataset[batch_idx:batch_idx + batch_size]
            data = collator(batch)
            # Move every tensor, including the nested "info" dict, onto
            # training_device.
            for key in data.keys():
                if key == "info":
                    for sub_key in data[key].keys():
                        data[key][sub_key] = data[key][sub_key].to(training_device)
                else:
                    data[key] = data[key].to(training_device)
            optimizer.zero_grad()
            output = training_model(**data)
            loss = output["loss"]
            loss.backward()
            optimizer.step()

    # Copy training_model's parameters into inference_model and save a checkpoint.
    inference_model.load_state_dict(training_model.state_dict())
    torch.save(inference_model.state_dict(), f"{checkpoint_dir}/epoch_{epoch}.pth")






  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:00<00:00,  2.18it/s]
100%|██████████| 1/1 [00:42<00:00, 42.08s/it]


In [None]:
# Display the legal action indices for the player to act.
env.get_valid_actions()

In [11]:
# Manually apply action index 46 for whichever player is to act.
env.step(env.get_curr_player_id(), 46)

### 4.1 检查模型效果

In [1]:
import os

# Expose only physical GPU 3. This must be set before torch initialises CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import torch

# With one visible GPU, "cuda:0" refers to physical GPU 3.
inference_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print("inference_device:", inference_device)

import numpy as np
from base_module import *

from pymahjong import *
from pymahjong import MahjongPyWrapper as pm
from pymahjong.myEnv_pymahjong import myMahjongEnv

env = myMahjongEnv()
env.reset()

config = GPT2Config(n_embd=512, n_layer=8, n_head=8, n_positions=128)
inference_model = Policy_Network(config).to(inference_device)
inference_collator = inference_Collator(device=inference_device)

# map_location remaps tensors saved from a CUDA device onto the device
# available now (including CPU), instead of failing when that GPU is absent.
# NOTE(review): with epochs = 1 the training cell only writes epoch_0.pth;
# epoch_2.pth must come from an earlier, longer run.
inference_model.load_state_dict(
    torch.load("./checkpoints/epoch_2.pth", map_location=inference_device, weights_only=True)
)
inference_model.eval()


inference_device: cuda:0


Policy_Network(
  (my_model): my_model(
    (tiles_cnn): Tiles_CNN(
      (conv1): Conv2d(35, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (resnet_list): ModuleList(
        (0-17): 18 x Rsidual_Block(
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (squential): Sequential(
            (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            (3): Conv2d(12

In [24]:
# Let the loaded model choose one action for the player to act, show the
# context (raw actions, player state, legal indices, chosen probability),
# then apply the action.
curr_pid = env.get_curr_player_id()
valid_actions = env.get_valid_actions()
obs = env.get_observation(curr_pid)
# `model_inputs`, not `input`, to avoid shadowing the builtin.
model_inputs = inference_collator(obs)
with torch.no_grad():
    output = inference_model.inference(**model_inputs)
action = output["action"].item()

phase = env.t.get_phase()
if phase < 4:
    aviable_action = env.t.get_self_actions()
elif phase < 16:
    aviable_action = env.t.get_response_actions()
else:
    # Only phases < 16 have an action list here. Use an empty list rather
    # than a stale one from an earlier run (or a NameError).
    aviable_action = []
print("aviable_action:")
ls = [av_action.to_string() for av_action in aviable_action]
# Deduplicate while keeping first-seen order (set() ordering is arbitrary).
print(list(dict.fromkeys(ls)))

print(env.t.players[curr_pid].to_string())
print("curr_pid:", curr_pid)
print("valid_actions:", valid_actions)
print("model selected action:", action, "prob:", output["action_probs"][0][action].item()*100, "%")

env.step(player_id=curr_pid, action_idx=action)

aviable_action:
['Discard 9m', 'Discard 8s', 'Discard 2m', 'Discard 4p', 'Discard 3p', 'Discard 2p', 'Discard 3s', 'Discard 4s']
Pt: 25000
Wind: North
Hand: 2m 9m 9m 2p 2p 3p 4p 3s 4s 4s 8s 
Calls: 8m(8m)8m 
River: 2z4 3m8h 1s12h 5p16- 
Riichi: No
Menzen: No
curr_pid: 0
valid_actions: [1, 8, 10, 11, 12, 20, 21, 25]
model selected action: 8 prob: 12.632443010807037 %


# 5.检查环境bug

In [None]:
import numpy as np
from base_module import *

from pymahjong import *
from pymahjong import MahjongPyWrapper as pm
from pymahjong.myEnv_pymahjong import myMahjongEnv

# Play random games until some decision point offers a "Riichi" action, then
# stop without stepping, so the env can be inspected in that state.
env = myMahjongEnv()

flag = True
while flag:
    env.reset()
    while not env.is_over():
        curr_pid = env.get_curr_player_id()
        phase = env.t.get_phase()
        if phase < 4:
            aviable_action = env.t.get_self_actions()
        elif phase < 16:
            aviable_action = env.t.get_response_actions()
        else:
            # Only phases < 16 have an action list here. Don't reuse the
            # previous iteration's list.
            aviable_action = []

        set_ls = {av_action.action.name for av_action in aviable_action}
        if "Riichi" in set_ls:
            flag = False
            print("phase:", phase)
            print("aviable_action:", set_ls)
            break

        valid_actions = env.get_valid_actions()
        action = np.random.choice(valid_actions)
        env.step(player_id=curr_pid, action_idx=action)






In [None]:
# Create and reset a fresh environment for the step-by-step inspection cells
# in this section.
import numpy as np
from base_module import *

from pymahjong import *
from pymahjong import MahjongPyWrapper as pm
from pymahjong.myEnv_pymahjong import myMahjongEnv

env = myMahjongEnv()
env.reset()

In [None]:
# Show the current phase, the acting player and the raw action list for that
# phase (numbered, then collected into one list).
phase = env.t.get_phase()
print("phase:", phase)
print("curr_pid:", env.get_curr_player_id())
if phase < 4:
    aviable_action = env.t.get_self_actions()
elif phase < 16:
    aviable_action = env.t.get_response_actions()
else:
    # Only phases < 16 have an action list here. Don't show a stale one.
    aviable_action = []
print("aviable_action:")
ls = []
for idx, action in enumerate(aviable_action):
    label = action.to_string()
    ls.append(label)
    print(idx, label)

print("aviable_action:", ls)

In [None]:
# Display the legal action indices for the player to act.
env.get_valid_actions()

In [39]:
# Manually apply action index 29 for whichever player is to act.
env.step(env.get_curr_player_id(), 29)

In [None]:
# Print the env's `riichi_stage2` and `pass_riichi` attributes, in that order.
for attr_name in ("riichi_stage2", "pass_riichi"):
    print(getattr(env, attr_name))

In [None]:
# Display the env's action_record attribute.
env.action_record

In [None]:
# Print the string representation of the underlying table object env.t.
print(env.t.to_string())

### 5.1 手操检查

In [None]:
# Apply selection 7 directly on the underlying table env.t (bypassing
# env.step), then show the resulting phase and its raw action list.
env.t.make_selection(7)
phase = env.t.get_phase()
if phase < 4:
    aviable_action = env.t.get_self_actions()
elif phase < 16:
    aviable_action = env.t.get_response_actions()
else:
    # Only phases < 16 have an action list here. Don't show a stale one.
    aviable_action = []
print("after phase:", phase)
print("after aviable_action:")
for idx, action in enumerate(aviable_action):
    print(idx, action.to_string())


In [None]:
# Apply selection 0 directly on the underlying table env.t (bypassing
# env.step), then show the resulting phase and its raw action list.
env.t.make_selection(0)
phase = env.t.get_phase()
if phase < 4:
    aviable_action = env.t.get_self_actions()
elif phase < 16:
    aviable_action = env.t.get_response_actions()
else:
    # Only phases < 16 have an action list here. Don't show a stale one.
    aviable_action = []
print("after phase:", phase)
print("after aviable_action:")
for idx, action in enumerate(aviable_action):
    print(idx, action.to_string())