In [1]:
import scipy

In [2]:
# data = scipy.io.loadmat("offline_data/20M/mahjong-offline-data-batch-0.mat")

In [3]:
# print(f"Executor obs shape: {data['X'].shape}\n"
#       f"Oracle obs shape: {data['O'].shape}\n"
#       f"Action selected shape: {data['A'].shape}\n"
#       f"Valid action shape: {data['M'].shape}\n"
#       f"Reward shape: {data['R'].shape}\n"
#       f"Done signal shape: {data['D'].shape}\n"
#       f"Last step shape: {data['V'].shape}")
# del data

In [4]:
from torch.utils.data import Dataset, DataLoader
import numpy as np
from torch import tensor
import torch


class RiichiDataset(Dataset):
    def __init__(self, root_path='offline_data/20M', index=0, oracle=False):
        file_path = f"{root_path}/mahjong-offline-data-batch-{index}.mat"
        raw_data = scipy.io.loadmat(file_path)
        action_array = np.reshape(raw_data['A'], -1)
        action_data = raw_data['A'].T
        if oracle:
            obs_data = np.concatenate([raw_data['X'], raw_data['O']], axis=1)
        else:
            obs_data = raw_data['X']
        # 41 立直 | 46 不立直
        riichi_mask = np.bitwise_or(action_array == 41, action_array == 46)
        self.obs_data_masked = obs_data[riichi_mask]
        self.action_data_masked = np.where(action_data[riichi_mask] == 41, 1, 0)

    def __len__(self):
        return len(self.action_data_masked)

    def __getitem__(self, idx):
        return tensor(self.obs_data_masked[idx], dtype=torch.float), tensor(self.action_data_masked[idx],
                                                                            dtype=torch.float)

In [5]:
riichi_dataset_valid = RiichiDataset(index=39)

In [6]:
from torch import nn


class RiichiModel(nn.Module):
    def __init__(self, in_channels=93):
        super().__init__()
        self.input_layer = nn.Sequential(nn.Conv1d(in_channels, 256, 3, padding=1),
                                         nn.ReLU())

        self.hidden_layer = nn.Sequential(nn.Conv1d(256, 256, 3, padding=1),
                                          nn.ReLU(),
                                          nn.Conv1d(256, 32, 3, padding=1),
                                          nn.ReLU(),
                                          nn.Flatten(),
                                          nn.Linear(32 * 34, 1024),
                                          nn.ReLU(),
                                          nn.Linear(1024, 256),
                                          nn.ReLU()
                                          )

        self.output_layer = nn.Sequential(nn.Linear(256, 1),
                                          nn.Sigmoid())

    def forward(self, obs):
        return self.output_layer(self.hidden_layer(self.input_layer(obs)))

In [7]:
riichi_model = RiichiModel()

In [8]:
def train(model: nn.Module, train_dataset: Dataset, test_dataset: Dataset, loss_fn: nn.Module,
          optimizer: torch.optim.Optimizer,
          epoch: int = 1, train_batch_size=10, eval_interval=1000):
    for e in range(epoch):
        dataloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
        for iter_time, (obs, target_action) in enumerate(dataloader):
            model.train()
            action = model(obs)
            loss = loss_fn(action, target_action)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if iter_time % eval_interval == 0:
                evaluate(model, train_dataset, test_dataset, loss_fn, f"# epoch_{e}_iter_{iter_time} #")

In [9]:
def evaluate(model: nn.Module, train_dataset: Dataset, test_dataset: Dataset, loss_fn: nn.Module, log_title="",
             test_batch_size=10):
    model.eval()
    train_dataloader = DataLoader(train_dataset, batch_size=test_batch_size, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=True)
    train_total_loss = []
    test_total_loss = []

    train_accuracy_total = 0
    train_accuracy = 0
    train_precision_total = 0
    train_precision = 0
    train_recall_total = 0
    train_recall = 0

    test_accuracy_total = 0
    test_accuracy = 0
    test_precision_total = 0
    test_precision = 0
    test_recall_total = 0
    test_recall = 0

    for obs, target_action in train_dataloader:
        action = model(obs)
        loss = loss_fn(action, target_action)
        train_total_loss.append(loss.item())

        action_bool = np.where(action > 0.5, True, False)
        target_action_bool = np.where(target_action > 0.5, True, False)

        train_accuracy_total += len(action_bool)
        train_accuracy += np.sum(action_bool == target_action_bool)
        train_precision_total += np.sum(action_bool)
        train_precision += np.sum(target_action_bool[action_bool])
        train_recall_total += np.sum(target_action_bool)
        train_recall += np.sum(action_bool[target_action_bool])

    for obs, target_action in test_dataloader:
        action = model(obs)
        loss = loss_fn(action, target_action)
        test_total_loss.append(loss.item())

        action_bool = np.where(action > 0.5, True, False)
        target_action_bool = np.where(target_action > 0.5, True, False)

        test_accuracy_total += len(action_bool)
        test_accuracy += np.sum(action_bool == target_action_bool)
        test_precision_total += np.sum(action_bool)
        test_precision += np.sum(target_action_bool[action_bool])
        test_recall_total += np.sum(target_action_bool)
        test_recall += np.sum(action_bool[target_action_bool])

    print(f"{log_title}\n"
          f"Train Loss: {np.mean(train_total_loss)} | Test Loss: {np.mean(test_total_loss)}\n"
          f"Train Accuracy: {train_accuracy / train_accuracy_total} | Test Accuracy: {test_accuracy / test_accuracy_total}\n"
          f"Train Precision: {train_precision / train_precision_total} | Test Precision: {test_precision / test_precision_total}\n"
          f"Train Recall: {train_recall / train_recall_total} | Test Recall: {test_recall / test_recall_total}"
          )

In [10]:
from torch.optim import Adam

adam_optimizer = Adam(riichi_model.parameters(), lr=0.0005)

for i in range(39):
    train(riichi_model, RiichiDataset(index=i), riichi_dataset_valid, nn.MSELoss(), adam_optimizer, epoch=1)

  f"Train Precision: {train_precision / train_precision_total} | Test Precision: {test_precision / test_precision_total}\n"


# epoch_0_iter_0 #
Train Loss: 0.24811898497038837 | Test Loss: 0.24807925103267947
Train Accuracy: 0.5494988434849652 | Test Accuracy: 0.5503870028354663
Train Precision: nan | Test Precision: nan
Train Recall: 0.0 | Test Recall: 0.0
# epoch_0_iter_1000 #
Train Loss: 0.16202351607169752 | Test Loss: 0.16459049701412348
Train Accuracy: 0.7612181958365459 | Test Accuracy: 0.7537742355736071
Train Precision: 0.6778036778036778 | Test Precision: 0.6709170530654301
Train Recall: 0.895772719493411 | Test Recall: 0.8878472814044657
# epoch_0_iter_0 #
Train Loss: 0.16802529196627577 | Test Loss: 0.1652248991339314
Train Accuracy: 0.7482665058788062 | Test Accuracy: 0.7562265307686413
Train Precision: 0.676350191469695 | Test Precision: 0.6834699453551912
Train Recall: 0.8521044751289303 | Test Recall: 0.8527356400204534
# epoch_0_iter_1000 #
Train Loss: 0.18640646765067523 | Test Loss: 0.1877184095150181
Train Accuracy: 0.7408049442267108 | Test Accuracy: 0.7412062226990574
Train Precision: 0

In [12]:
evaluate(riichi_model, RiichiDataset(), riichi_dataset_valid, nn.MSELoss())


Train Loss: 0.14119381422490032 | Test Loss: 0.14254177743390603
Train Accuracy: 0.7984579799537393 | Test Accuracy: 0.7956931565637213
Train Precision: 0.7402172295789317 | Test Precision: 0.7311191335740073
Train Recall: 0.8514461749101488 | Test Recall: 0.8629623316856997


In [13]:
torch.save(riichi_model.state_dict(), "weights/riichi_model_0.795a.pt")