In [19]:
import pandas as pd
from csgo_wp.data_transform import CSGODataset, transform_multichannel
from csgo_wp.model import LR_CNN
from sklearn.metrics import log_loss, roc_auc_score, accuracy_score
import torch

import matplotlib.pyplot as plt
%matplotlib inline

import warnings
warnings.filterwarnings('ignore')

def _permute_players(data, permutation):
    """Return a copy of a batch with the player axes reordered by ``permutation``.

    Applies the same channel/axis pattern the evaluation has always used:
    rows of channels 0, 2 and 4 and columns of channels 0, 3 and 4 are
    reindexed, while channels 1 and 5 are left untouched.
    NOTE(review): which axis holds T players per channel is inferred only from
    the "T players" intent of this evaluation — confirm against
    ``transform_multichannel``.

    Args:
        data: Batch tensor indexed as ``data[batch, channel, row, col]``
            (presumably ``(B, 6, 5, 5)`` per the model's ``input_size`` —
            TODO confirm); rows and columns are assumed to have equal length.
        permutation: New order of the player indices; must be a permutation
            of ``range(data.shape[-1])``.

    Returns:
        A new tensor; ``data`` itself is not modified.

    Raises:
        ValueError: If ``permutation`` is not a permutation of the player indices.
    """
    n_players = data.shape[-1]
    if sorted(permutation) != list(range(n_players)):
        raise ValueError(
            f'permutation must reorder range({n_players}), got {list(permutation)}'
        )

    idx = torch.as_tensor(permutation, dtype=torch.long, device=data.device)
    permuted = data.clone()
    # index_select names the reordered axis explicitly; the previous
    # data[:, c, :, perm] form relied on torch-specific rules for mixing an
    # integer index with a list index. Row and column reorders commute, so
    # applying all row reorders before all column reorders is equivalent.
    for channel in (0, 2, 4):  # channels whose rows are reindexed
        permuted[:, channel] = permuted[:, channel].index_select(1, idx)
    for channel in (0, 3, 4):  # channels whose columns are reindexed
        permuted[:, channel] = permuted[:, channel].index_select(2, idx)
    return permuted


def test(model, loader, device, permutation=(0, 3, 2, 1, 4)):
    """Evaluate ``model`` on permuted batches from ``loader``; print and return metrics.

    By default the 2nd and 4th T players (0-based indices 1 and 3) are swapped
    in every batch before it reaches the model, to probe whether predictions
    depend on player ordering. The model is moved to ``device`` in place and
    switched to eval mode.

    Args:
        model: Module mapping a batch to one score per sample. Scores are
            thresholded at 0.5 and fed to AUC / log loss, so they are presumably
            probabilities — TODO confirm the model ends in a sigmoid.
        loader: Iterable of ``(data, target)`` tensor batches.
        device: Device the model and inputs are moved to.
        permutation: New order of the player indices applied to every batch;
            ``None`` evaluates the data unpermuted.

    Returns:
        dict with float values under ``'accuracy'``, ``'auc'`` and ``'log_loss'``.

    Raises:
        ValueError: If ``loader`` yields no batches or ``permutation`` is invalid.
    """
    model.eval()
    model.to(device)

    targets = []
    outputs = []

    with torch.no_grad():
        for data, target in loader:
            targets.append(target)
            if permutation is not None:
                data = _permute_players(data, permutation)
            outputs.append(model(data.to(device)))

    if not outputs:
        # torch.cat([]) would otherwise fail with an unrelated-looking error.
        raise ValueError('loader yielded no batches; nothing to evaluate')

    y_pred = torch.cat(outputs, dim=0).cpu().numpy().astype(float)
    y_true = torch.cat(targets, dim=0).cpu().numpy().astype(float)

    metrics = {
        'accuracy': accuracy_score(y_true, y_pred > 0.5),
        'auc': roc_auc_score(y_true, y_pred),
        'log_loss': log_loss(y_true, y_pred),
    }

    print('\n' + '-' * 30)
    print('Results')
    print(f'Accuracy: {metrics["accuracy"]:.4f}')
    print(f'AUC: {metrics["auc"]:.4f}')
    print(f'Log loss: {metrics["log_loss"]:.4f}')

    return metrics

# Test split, iterated in a fixed order so the reported metrics are reproducible.
test_dataset = CSGODataset(transform=transform_multichannel,
                           dataset_split='test',
                           verbose=False,
                           )

test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=64,
                                          shuffle=False,
                                          num_workers=0,
                                          )

# Fall back to CPU so Restart & Run All also works on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

MODEL_PATH = 'csgo_wp/model-final.pt'

# NOTE(review): the meaning of each 8-field cnn_options tuple is defined by
# LR_CNN, which is not visible here.
model = LR_CNN(input_size=(6, 5, 5),
               hidden_sizes=[200, 100, 50],
               activation='LeakyReLU',
               activation_params={},
               dropout=False,
               batch_norm=False,
               cnn_options=((4, 6, 1, 1, 0, 1, 1, 0),
                            (6, 6, 1, 1, 0, 1, 1, 0),
                            (6, 6, 5, 1, 0, 1, 1, 0),),
               )
# map_location='cpu' lets a checkpoint saved from a GPU load on any machine;
# the model is moved to `device` later, when it is evaluated.
model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))
model.eval();

Reading transformed data...

Done!


In [20]:
# Evaluate on the test split with the 2nd and 4th T players swapped in every batch.
test(model=model, loader=test_loader, device=device)


------------------------------
Results
Accuracy: 0.6963
AUC: 0.8290
Log loss: 0.5143
