In [214]:
from parsing import Parser
import re
import torch
from tqdm import tqdm

In [2]:
parser = Parser('replays_raw')
replays_data = parser.parse_replays()

1052it [07:15,  2.44it/s]

Error processing replays_raw\[sc2rep.ru_1565107965]_1x1_TSGSolar(Z)_llllllllllll(P).SC2Replay: Replacement index 2 out of range for positional args tuple


1720it [13:33,  2.11it/s]


In [234]:
save_path = 'replays_processed/baseline/processed_replays.pt'

def concatenate_rows_single_last_column(tensor):
    sorted_tensor, indices = torch.sort(tensor, dim=0, descending=False)
    last_col_vals = sorted_tensor[:, -1]
    unique_vals, counts = torch.unique(last_col_vals, return_counts=True)
    paired_indices = unique_vals[counts == 2]
    mask = torch.isin(last_col_vals, paired_indices)
    to_concatenate = sorted_tensor[mask]
    to_concatenate = to_concatenate.view(-1, 2, to_concatenate.size(1))
    concatenated = torch.cat((to_concatenate[:, 0, :-1], to_concatenate[:, 1, :]), dim=1)
    return concatenated

full_data = []
for game in tqdm(replays_data):
    game_data = []
    replay = game[0]
    winner_text = game[1]
    winner = 2-int(re.search(r"Team (\d+)", str(winner_text)).group(1))
    for timestep in replay:
        stats = torch.tensor(list(timestep.stats.values()))
        second = torch.tensor(timestep.second).unsqueeze(0)
        full_stats = torch.concat((stats, second), dim=-1)
        game_data.append(full_stats)
    game_data_tensor = torch.stack(game_data)
    game_data_tensor = concatenate_rows_single_last_column(game_data_tensor)
    win_tensor = torch.full((game_data_tensor.size(0), 1), winner)
    game_tensor = torch.cat((game_data_tensor, win_tensor), dim=1)
    full_data.append(game_tensor)
full_data_tensor = torch.concat(full_data, dim=0)
torch.save(full_data_tensor, save_path)

  0%|          | 0/1719 [00:00<?, ?it/s]

100%|██████████| 1719/1719 [00:08<00:00, 198.18it/s]


### Creating the Dataloader (execute from here)

In [235]:
import torch
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn
from torch.optim import Adam

save_path = 'replays_processed/baseline/processed_replays.pt'
full_data_tensor = torch.load(save_path)
full_data_tensor.shape

torch.Size([169701, 80])

### Making the TensorDataset and Dataloader

In [236]:
batch_size = 512
hidden_size = 64
learning_rate = 1e-3
epochs = 500

features = full_data_tensor[:, :-1]
labels = full_data_tensor[:, -1].long()
dataset = TensorDataset(features, labels)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [237]:
next(iter(dataloader))[0].shape

torch.Size([512, 79])

### Creating the baseline MLP model

In [238]:
class MLPClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.ffwd = nn.Sequential(
            nn.BatchNorm1d(79),
            nn.Linear(79, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.ffwd(x)

model = MLPClassifier()

### Training the model

In [239]:
optimizer = Adam(model.parameters(), lr=learning_rate)  # Using Adam optimizer
bceloss = nn.BCELoss()

In [240]:
for epoch in range(epochs):
    for batch_idx, (x, labels) in enumerate(dataloader):
        optimizer.zero_grad()
        
        out = model(x.float())
        labels = labels.unsqueeze(1).float()
        loss = bceloss(out, labels)

        loss.backward()
        optimizer.step()

        if (int(batch_idx) == 0) & (epoch % 5 == 0):
            print(f"Epoch [{epoch+1}/{epochs}], Batch {batch_idx}, Loss: {loss.item()}")

Epoch [1/500], Batch 0, Loss: 0.7104644775390625
Epoch [6/500], Batch 0, Loss: 0.5893709063529968
Epoch [11/500], Batch 0, Loss: 0.5423356294631958
Epoch [16/500], Batch 0, Loss: 0.5362827777862549
Epoch [21/500], Batch 0, Loss: 0.5354219079017639
Epoch [26/500], Batch 0, Loss: 0.4985294044017792
Epoch [31/500], Batch 0, Loss: 0.5103408694267273
Epoch [36/500], Batch 0, Loss: 0.4838268756866455
Epoch [41/500], Batch 0, Loss: 0.48968470096588135
Epoch [46/500], Batch 0, Loss: 0.4690924286842346
Epoch [51/500], Batch 0, Loss: 0.4963911175727844
Epoch [56/500], Batch 0, Loss: 0.49744921922683716
Epoch [61/500], Batch 0, Loss: 0.4513827860355377
Epoch [66/500], Batch 0, Loss: 0.4415157437324524
Epoch [71/500], Batch 0, Loss: 0.4613882601261139
Epoch [76/500], Batch 0, Loss: 0.4552359879016876
Epoch [81/500], Batch 0, Loss: 0.49122709035873413
Epoch [86/500], Batch 0, Loss: 0.4760972857475281
Epoch [91/500], Batch 0, Loss: 0.4421740174293518
Epoch [96/500], Batch 0, Loss: 0.4339776337146759

In [244]:
import sc2reader
model.eval()

@torch.no_grad
def test_replay():
    game_data = []
    replay = sc2reader.load_replay('test.SC2Replay')
    for event in replay.events:
        if event.name == "PlayerStatsEvent":
            stats = torch.tensor(list(event.stats.values()))
            time = torch.tensor(event.second).unsqueeze(0)
            x = torch.concat((stats, time))
            game_data.append(x)
    game_data_tensor = torch.stack(game_data)
    game_data_tensor = concatenate_rows_single_last_column(game_data_tensor)
    print(f'time:{time}, predicted win rate: {model(game_data_tensor.float())}')

test_replay()

time:tensor([925]), predicted win rate: tensor([[0.5591],
        [0.5779],
        [0.5547],
        [0.5588],
        [0.5522],
        [0.5461],
        [0.5373],
        [0.5350],
        [0.5282],
        [0.5276],
        [0.5214],
        [0.5136],
        [0.5146],
        [0.5066],
        [0.5110],
        [0.5125],
        [0.5127],
        [0.5152],
        [0.5077],
        [0.5092],
        [0.4970],
        [0.4873],
        [0.5340],
        [0.5264],
        [0.5074],
        [0.5081],
        [0.4966],
        [0.4536],
        [0.4038],
        [0.3600],
        [0.4010],
        [0.4298],
        [0.3433],
        [0.3794],
        [0.4067],
        [0.5139],
        [0.4984],
        [0.5696],
        [0.6665],
        [0.6261],
        [0.5215],
        [0.5941],
        [0.5126],
        [0.6159],
        [0.6960],
        [0.7614],
        [0.7909],
        [0.3387],
        [0.2834],
        [0.6304],
        [0.7494],
        [0.8203],
        [0.7824],
      