In [1]:
import pandas as pd
import numpy as np
from pqd_dataset import PQDDataset
from sklearn.utils.class_weight import compute_class_weight
import torch
from torch.utils.data import DataLoader
from pqd_trainer import train_model, evaluate_model
from models.conv_mlp import CONV_MLP

In [2]:
window_size = 32
stride = 1
target_mode = True
epochs=300
batch_size=64
lr = 0.0005

In [3]:
train_df = pd.read_csv('train.csv')
train_dataset = PQDDataset(
                            df=train_df,
                            window_size=window_size,
                            stride=stride,
                            target_mode=target_mode
                            )

Creating window slices:   0%|          | 0/66 [00:00<?, ?it/s]

In [4]:
weights = compute_class_weight("balanced", classes=np.unique(train_dataset.df.target), y=train_dataset.df.target)
weights = torch.FloatTensor(weights)

In [5]:
test_df = pd.read_csv('test.csv')
test_dataset = PQDDataset(
                            df=test_df,
                            window_size=window_size,
                            stride=stride,
                            target_mode=target_mode
                            )

Creating window slices:   0%|          | 0/13 [00:00<?, ?it/s]

In [6]:
model = CONV_MLP(frame_size=32)

In [7]:
train_model(
            model=model,
            dataset=train_dataset,
            eval_dataset=test_dataset,
            epochs=epochs,
            batch_size=batch_size,
            lr=lr,
            weights=None)

Epoch ...:   0%|          | 0/300 [00:00<?, ?it/s]

Step ...:   0%|          | 0/974 [00:00<?, ?it/s]

Step ...:   0%|          | 0/271 [00:00<?, ?it/s]

[0.8004129  0.5392073  0.93370474 0.22888283]
0.6085143866197059


Step ...:   0%|          | 0/974 [00:00<?, ?it/s]

Step ...:   0%|          | 0/271 [00:00<?, ?it/s]

[0.86273473 0.78072884 0.94033722 0.34266886]
0.7366941172413912


Step ...:   0%|          | 0/974 [00:00<?, ?it/s]

Step ...:   0%|          | 0/271 [00:00<?, ?it/s]

[0.85012807 0.70824785 0.95234529 0.47665056]
0.7580230209101093


Step ...:   0%|          | 0/974 [00:00<?, ?it/s]

Step ...:   0%|          | 0/271 [00:00<?, ?it/s]

[0.86264262 0.74055281 0.95891416 0.48698885]
0.7570065819286672


Step ...:   0%|          | 0/974 [00:00<?, ?it/s]

Step ...:   0%|          | 0/271 [00:00<?, ?it/s]

[0.85094013 0.70647042 0.95726655 0.50294695]
0.7416753877730978


Step ...:   0%|          | 0/974 [00:00<?, ?it/s]

Step ...:   0%|          | 0/271 [00:00<?, ?it/s]

[0.84455895 0.6783307  0.95931359 0.54545455]
0.7442737821530778


Step ...:   0%|          | 0/974 [00:00<?, ?it/s]

Step ...:   0%|          | 0/271 [00:00<?, ?it/s]

[0.86537123 0.75190381 0.96483435 0.58867925]
0.7848813559870069


Step ...:   0%|          | 0/974 [00:00<?, ?it/s]

Step ...:   0%|          | 0/271 [00:00<?, ?it/s]

[0.87980492 0.78713389 0.97013142 0.58267717]
0.7915937453660105


Step ...:   0%|          | 0/974 [00:00<?, ?it/s]

Step ...:   0%|          | 0/271 [00:00<?, ?it/s]

[0.87376049 0.7682455  0.96523966 0.55531453]
0.7659112973211148


Step ...:   0%|          | 0/974 [00:00<?, ?it/s]

Step ...:   0%|          | 0/271 [00:00<?, ?it/s]

[0.89275775 0.82108218 0.96758835 0.54014599]
0.7693630025963558


Step ...:   0%|          | 0/974 [00:00<?, ?it/s]

Step ...:   0%|          | 0/271 [00:00<?, ?it/s]

[0.8778689  0.77643625 0.9701698  0.64052288]
0.7881257805092481


Step ...:   0%|          | 0/974 [00:00<?, ?it/s]

Step ...:   0%|          | 0/271 [00:00<?, ?it/s]

[0.89370363 0.82938448 0.97166142 0.68729642]
0.8647616857622922


Step ...:   0%|          | 0/974 [00:00<?, ?it/s]

Step ...:   0%|          | 0/271 [00:00<?, ?it/s]

[0.89331418 0.83067253 0.96507524 0.59348199]
0.8272062872941692


Step ...:   0%|          | 0/974 [00:00<?, ?it/s]

Step ...:   0%|          | 0/271 [00:00<?, ?it/s]

[0.91421272 0.87562664 0.97025886 0.67088608]
0.8363245918409972


Step ...:   0%|          | 0/974 [00:00<?, ?it/s]

Step ...:   0%|          | 0/271 [00:00<?, ?it/s]

[0.89902964 0.83691309 0.96589435 0.62807018]
0.8367638244608264


Step ...:   0%|          | 0/974 [00:00<?, ?it/s]

Step ...:   0%|          | 0/271 [00:00<?, ?it/s]

[0.9183217  0.88821114 0.96489137 0.54965358]
0.8026993838408375


Step ...:   0%|          | 0/974 [00:00<?, ?it/s]

Step ...:   0%|          | 0/271 [00:00<?, ?it/s]

[0.89153646 0.82886961 0.96174057 0.53043478]
0.7792788882581465


Step ...:   0%|          | 0/974 [00:00<?, ?it/s]

Step ...:   0%|          | 0/271 [00:00<?, ?it/s]

[0.90736563 0.85470085 0.97082569 0.59130435]
0.8068401952398861


Step ...:   0%|          | 0/974 [00:00<?, ?it/s]

Step ...:   0%|          | 0/271 [00:00<?, ?it/s]

[0.90452593 0.85218034 0.96806661 0.543379  ]
0.7886211681132528


Step ...:   0%|          | 0/974 [00:00<?, ?it/s]

Step ...:   0%|          | 0/271 [00:00<?, ?it/s]

[0.90872884 0.8614972  0.96956803 0.5904059 ]
0.8295451639055021


Step ...:   0%|          | 0/974 [00:00<?, ?it/s]

Step ...:   0%|          | 0/271 [00:00<?, ?it/s]

[0.9043741  0.84877807 0.97169724 0.65020576]
0.8242774321842767


Step ...:   0%|          | 0/974 [00:00<?, ?it/s]

Step ...:   0%|          | 0/271 [00:00<?, ?it/s]

[0.90770858 0.85192498 0.97086141 0.39215686]
0.7536124784190907


Step ...:   0%|          | 0/974 [00:00<?, ?it/s]

Step ...:   0%|          | 0/271 [00:00<?, ?it/s]

[0.91120108 0.87169767 0.97157671 0.66666667]
0.8592729758032425


Step ...:   0%|          | 0/974 [00:00<?, ?it/s]

Step ...:   0%|          | 0/271 [00:00<?, ?it/s]

[0.89387941 0.81698427 0.97072362 0.53759398]
0.7977006996273802


Step ...:   0%|          | 0/974 [00:00<?, ?it/s]

Step ...:   0%|          | 0/271 [00:00<?, ?it/s]

[0.90614973 0.86238975 0.96709238 0.63109756]
0.8714965518320505


Step ...:   0%|          | 0/974 [00:00<?, ?it/s]

Step ...:   0%|          | 0/271 [00:00<?, ?it/s]

[0.91379195 0.87855173 0.97434473 0.72694394]
0.8758275353825914


Step ...:   0%|          | 0/974 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
evaluate_model(model, test_dataset)