In [1]:
import scipy

In [20]:
# import numpy as np
#
# file_path = "offline_data/20M/mahjong-offline-data-batch-0.mat"
# raw_data = scipy.io.loadmat(file_path)
# action_array = np.reshape(raw_data['A'], -1)
# action_data = raw_data['A'].T
# obs_data = raw_data['X']
# mask = np.bitwise_and(action_array >= 34, action_array <= 45)
# mask = np.bitwise_and(mask, action_array != 42)
# mask = np.bitwise_and(mask, action_array != 43)
# obs_data_masked = obs_data[mask]
# action_data_masked = action_data[mask]
#
# import pandas as pd
#
# pd.DataFrame(action_data_masked).groupby(0).value_counts()

0
34     2075
35     2253
36     2052
37     8700
38      381
39       42
40      243
41     5843
44       91
45    79779
dtype: int64

In [2]:
from torch.utils.data import Dataset, DataLoader
import numpy as np
from torch import tensor
import torch


class PongDataset(Dataset):
    def __init__(self, root_path='offline_data/20M', index=0, oracle=False):
        file_path = f"{root_path}/mahjong-offline-data-batch-{index}.mat"
        raw_data = scipy.io.loadmat(file_path)
        action_array = np.reshape(raw_data['A'], -1)
        action_data = raw_data['A'].T
        if oracle:
            obs_data = np.concatenate([raw_data['X'], raw_data['O']], axis=1)
        else:
            obs_data = raw_data['X']
        # 37 碰 | 45 不响应
        # 样本不平衡
        pong_mask = np.bitwise_or(action_array == 37, action_array == 45)
        self.obs_data_masked = obs_data[pong_mask]
        self.action_data_masked = np.where(action_data[pong_mask] == 37, 1, 0)

    def __len__(self):
        return len(self.action_data_masked)

    def __getitem__(self, idx):
        return tensor(self.obs_data_masked[idx], dtype=torch.float), tensor(self.action_data_masked[idx],
                                                                            dtype=torch.float)

In [3]:
pong_dataset_valid = PongDataset(index=39)

In [4]:
from torch import nn


class PongModel(nn.Module):
    def __init__(self, in_channels=93):
        super().__init__()
        self.input_layer = nn.Sequential(nn.Conv1d(in_channels, 256, 3, padding=1),
                                         nn.ReLU())

        self.hidden_layer = nn.Sequential(nn.Conv1d(256, 256, 3, padding=1),
                                          nn.ReLU(),
                                          nn.Conv1d(256, 32, 3, padding=1),
                                          nn.ReLU(),
                                          nn.Flatten(),
                                          nn.Linear(32 * 34, 1024),
                                          nn.ReLU(),
                                          nn.Linear(1024, 256),
                                          nn.ReLU()
                                          )

        self.output_layer = nn.Sequential(nn.Linear(256, 1),
                                          nn.Sigmoid())

    def forward(self, obs):
        return self.output_layer(self.hidden_layer(self.input_layer(obs)))

In [5]:
pong_model = PongModel()

In [6]:
def train(model: nn.Module, train_dataset: Dataset, test_dataset: Dataset, loss_fn: nn.Module,
          optimizer: torch.optim.Optimizer,
          epoch: int = 1, train_batch_size=10, eval_interval=1000):
    for e in range(epoch):
        dataloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
        for iter_time, (obs, target_action) in enumerate(dataloader):
            model.train()
            action = model(obs)
            loss = loss_fn(action, target_action)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if iter_time % eval_interval == 0:
                evaluate(model, train_dataset, test_dataset, loss_fn, f"# epoch_{e}_iter_{iter_time} #")

In [7]:
def evaluate(model: nn.Module, train_dataset: Dataset, test_dataset: Dataset, loss_fn: nn.Module, log_title="",
             test_batch_size=10):
    model.eval()
    train_dataloader = DataLoader(train_dataset, batch_size=test_batch_size, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=True)
    train_total_loss = []
    test_total_loss = []

    train_accuracy_total = 0
    train_accuracy = 0
    train_precision_total = 0
    train_precision = 0
    train_recall_total = 0
    train_recall = 0

    test_accuracy_total = 0
    test_accuracy = 0
    test_precision_total = 0
    test_precision = 0
    test_recall_total = 0
    test_recall = 0

    for obs, target_action in train_dataloader:
        action = model(obs)
        loss = loss_fn(action, target_action)
        train_total_loss.append(loss.item())

        action_bool = np.where(action > 0.5, True, False)
        target_action_bool = np.where(target_action > 0.5, True, False)

        train_accuracy_total += len(action_bool)
        train_accuracy += np.sum(action_bool == target_action_bool)
        train_precision_total += np.sum(action_bool)
        train_precision += np.sum(target_action_bool[action_bool])
        train_recall_total += np.sum(target_action_bool)
        train_recall += np.sum(action_bool[target_action_bool])

    for obs, target_action in test_dataloader:
        action = model(obs)
        loss = loss_fn(action, target_action)
        test_total_loss.append(loss.item())

        action_bool = np.where(action > 0.5, True, False)
        target_action_bool = np.where(target_action > 0.5, True, False)

        test_accuracy_total += len(action_bool)
        test_accuracy += np.sum(action_bool == target_action_bool)
        test_precision_total += np.sum(action_bool)
        test_precision += np.sum(target_action_bool[action_bool])
        test_recall_total += np.sum(target_action_bool)
        test_recall += np.sum(action_bool[target_action_bool])

    print(f"{log_title}\n"
          f"Train Loss: {np.mean(train_total_loss)} | Test Loss: {np.mean(test_total_loss)}\n"
          f"Train Accuracy: {train_accuracy / train_accuracy_total} | Test Accuracy: {test_accuracy / test_accuracy_total}\n"
          f"Train Precision: {train_precision / train_precision_total} | Test Precision: {test_precision / test_precision_total}\n"
          f"Train Recall: {train_recall / train_recall_total} | Test Recall: {test_recall / test_recall_total}"
          )

In [8]:
from torch.optim import Adam

adam_optimizer = Adam(pong_model.parameters(), lr=0.0005)

for i in range(39):
    train(pong_model, PongDataset(index=i), pong_dataset_valid, nn.MSELoss(), adam_optimizer, epoch=1)

# epoch_0_iter_0 #
Train Loss: 0.2494403034613668 | Test Loss: 0.2494404250708395
Train Accuracy: 0.8492297607341855 | Test Accuracy: 0.8483759424068931
Train Precision: 0.19148936170212766 | Test Precision: 0.1938509640437728
Train Recall: 0.16551724137931034 | Test Recall: 0.17154715240949966


  f"Train Precision: {train_precision / train_precision_total} | Test Precision: {test_precision / test_precision_total}\n"


# epoch_0_iter_1000 #
Train Loss: 0.09832981898690192 | Test Loss: 0.09835619731120898
Train Accuracy: 0.9016715830875123 | Test Accuracy: 0.9016609035768948
Train Precision: nan | Test Precision: nan
Train Recall: 0.0 | Test Recall: 0.0
# epoch_0_iter_2000 #
Train Loss: 0.09832730745916907 | Test Loss: 0.0983335241538545
Train Accuracy: 0.9016715830875123 | Test Accuracy: 0.9016609035768948
Train Precision: nan | Test Precision: nan
Train Recall: 0.0 | Test Recall: 0.0
# epoch_0_iter_3000 #
Train Loss: 0.09832856324492915 | Test Loss: 0.09834486073126478
Train Accuracy: 0.9016715830875123 | Test Accuracy: 0.9016609035768948
Train Precision: nan | Test Precision: nan
Train Recall: 0.0 | Test Recall: 0.0
# epoch_0_iter_4000 #
Train Loss: 0.09833107480466043 | Test Loss: 0.09834486072619693
Train Accuracy: 0.9016715830875123 | Test Accuracy: 0.9016609035768948
Train Precision: nan | Test Precision: nan
Train Recall: 0.0 | Test Recall: 0.0
# epoch_0_iter_5000 #
Train Loss: 0.0983285632567

KeyboardInterrupt: 

In [12]:
evaluate(pong_model, PongDataset(), pong_dataset_valid, nn.MSELoss())


Train Loss: 0.14119381422490032 | Test Loss: 0.14254177743390603
Train Accuracy: 0.7984579799537393 | Test Accuracy: 0.7956931565637213
Train Precision: 0.7402172295789317 | Test Precision: 0.7311191335740073
Train Recall: 0.8514461749101488 | Test Recall: 0.8629623316856997


In [13]:
torch.save(pong_model.state_dict(), "weights/pong_model_0.795a.pt")