In [46]:
import numpy as np

In [47]:
from src.processing.features.zyskowski import ZyskowskiFeatureExtractor
from src.processing.process import process
from src.model.data.dataset import DuelDataset

# Run the Zyskowski feature pipeline over the raw training directory and wrap
# the resulting (features, labels) pair in a DuelDataset.
TRAIN_DIR = "../../../res/train"

train_extractor = ZyskowskiFeatureExtractor()
train_data = process(path=TRAIN_DIR, extractor=train_extractor)
train_features = train_data[0]
train_labels = train_data[1]
train_dataset = DuelDataset(features=train_features, labels=train_labels)

# Sanity-check the tensor shapes (temporal block, snapshot block, labels).
for attr_name in ("x_temporal", "x_snapshot", "y"):
    print(getattr(train_dataset, attr_name).shape)

torch.Size([130, 14, 4])
torch.Size([130, 47])
torch.Size([130])


In [48]:
# Default weights for collapsing the last axis of `x_temporal` into one value
# per timestep. Their length must match that axis (4 in the shapes printed above).
# NOTE(review): these weights sum to 1.1, not 1.0 — confirm that is intentional.
kernel = np.array([0.1, 0.2, 0.3, 0.5])


def get_xy(dataset: "DuelDataset", weights=None):
    """Flatten a DuelDataset into a 2-D feature matrix and a label vector.

    The temporal block is reduced along its last axis with a weighted sum. The
    result is concatenated column-wise in front of the snapshot features.

    Args:
        dataset: object exposing ``x_temporal`` (3-D), ``x_snapshot`` (2-D)
            and ``y`` as tensors with a ``.numpy()`` method.
        weights: 1-D weights for the last axis of ``x_temporal``. When
            ``None`` (the default), the module-level ``kernel`` is used. It is
            looked up at call time, so later reassignments are honored.

    Returns:
        ``(x, y)`` as numpy arrays. ``x`` has shape
        ``(n_samples, x_temporal.shape[1] + x_snapshot.shape[1])``, with the
        temporal columns first.

    Raises:
        ValueError: if ``x_temporal`` is not 3-D, or if ``weights`` is not
            1-D with length equal to ``x_temporal.shape[-1]``.
    """
    if weights is None:
        weights = kernel
    weights = np.asarray(weights)

    temporal = dataset.x_temporal.numpy()
    if temporal.ndim != 3:
        raise ValueError(
            f"expected x_temporal to be 3-D, got shape {temporal.shape}"
        )
    # Without this check a length-1 kernel would silently broadcast. Other
    # mismatches would surface as an opaque broadcasting error.
    if weights.ndim != 1 or weights.shape[0] != temporal.shape[-1]:
        raise ValueError(
            f"weights must be 1-D with length {temporal.shape[-1]}, "
            f"got shape {weights.shape}"
        )

    # Weighted sum over the last axis: (n, t, k) * (k,) -> (n, t).
    x_temporal = np.sum(temporal * weights, axis=-1)
    x_snapshot = dataset.x_snapshot.numpy()
    x = np.concatenate((x_temporal, x_snapshot), axis=1)
    y = dataset.y.numpy()
    return x, y


# Flatten the training set into a feature matrix `x` and label vector `y`.
x, y = get_xy(train_dataset)
print(*(arr.shape for arr in (x, y)))

(130, 61) (130,)


In [49]:
from sklearn import preprocessing

# Standardize each column using statistics estimated from `x`. Keep `scaler`
# around and use `scaler.transform` (not a fresh fit) on any held-out data so
# it is scaled with these same training statistics.
scaler = preprocessing.StandardScaler()
x_scaled = scaler.fit(x).transform(x)

In [50]:
from sklearn.linear_model import LogisticRegression

# Fit a logistic-regression classifier with library-default hyperparameters
# on the standardized training features.
logreg = LogisticRegression(random_state=0)
logreg.fit(X=x_scaled, y=y)

In [51]:
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix

# NOTE(review): `logreg` was fit on `x_scaled`, so these are in-sample
# (training-set) metrics and will overstate performance on unseen data.
# Evaluate on a held-out split or with cross-validation before drawing
# conclusions.
y_pred = logreg.predict(x_scaled)

train_accuracy = accuracy_score(y, y_pred)
print(train_accuracy)

train_f1 = f1_score(y, y_pred)
print(train_f1)

# The 4-way unpack below assumes a 2x2 confusion matrix, i.e. exactly two classes.
train_confusion = confusion_matrix(y, y_pred)
tn, fp, fn, tp = train_confusion.ravel()
print(tn, fp, fn, tp)

0.8461538461538461
0.863013698630137
47 11 9 63


In [53]:
# Coefficient j corresponds to column j of `x`. Following the concatenation
# order in `get_xy`, the first x_temporal.shape[1] columns (14 in the run
# above) are the kernel-weighted temporal features, and the remaining 47 are
# the snapshot features. The inputs were standardized, so magnitudes are
# roughly comparable across features.
# NOTE(review): coefficient 46 prints as exactly 0. That is possibly a
# constant (zero-variance) snapshot column; confirm and consider dropping it.
coefficients = logreg.coef_
print(coefficients)


[[-0.0951648  -0.49532561  0.48543251  0.26563565 -0.03295261 -0.31016322
   0.42344976  0.14980793 -0.30223439  0.20812419 -0.05650021 -0.172256
  -0.8141138   1.1207008  -0.03108254  0.10758362 -0.80261813 -0.23988635
  -0.13949292  1.01241969  0.30045102  0.39119244 -0.26389388  0.09795354
   0.54883365 -0.0461392  -0.40020512  0.02775397  0.16039075 -0.03547429
  -0.04950099  0.19297656 -0.28893582 -0.40396812  0.14456455  0.12922684
  -0.18074992  0.21182124 -0.31147594 -0.63956816  0.48803489 -0.05517205
   0.53481    -0.27183765 -0.20368349  0.58971393  0.          0.06580895
  -0.07033787  0.66936331 -0.23652142  0.14621135  1.1126754  -0.06789896
   0.17929586 -0.08578714 -0.46127609 -0.05618193 -0.11770297 -0.57803011
  -0.11340495]]
