In [1]:
import pandas as pd
import numpy as np
from pqd_dataset import PQDDataset
from sklearn.utils.class_weight import compute_class_weight
import torch
from torch.utils.data import DataLoader
from pqd_trainer import train_model, train_fft_model, evaluate_model
from models.conv_mlp import CONV_MLP
import random
import os

In [2]:
def seed_everything(seed: int = 42):
    """
    Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU and every visible CUDA device), and switches cuDNN to
    deterministic mode with auto-tuning disabled.

    Parameters
    ----------
    seed : int, default 42
        Value handed to each generator.

    Returns
    -------
    None
    """
    random.seed(seed)
    # PYTHONHASHSEED is read when the interpreter starts, so assigning it here
    # does not change str/bytes hashing in the running kernel; it is only
    # inherited by processes spawned afterwards.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all CUDA devices rather than only the current one; PyTorch documents
    # this call as a silent no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # Prefer deterministic cuDNN kernels and disable the benchmark auto-tuner,
    # which may otherwise select different algorithms between runs.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [3]:
# Seed used for this run; it is also passed to the trainer as `run_seed` below.
seed = 29
# Seed the global RNGs before any dataset or model object is created.
seed_everything(seed)

In [4]:
# --- Dataset windowing (forwarded to PQDDataset) ---
window_size = 32    # passed as `window_size` to both the train and test datasets
stride = 1          # passed as `stride` to both datasets
target_mode = True  # NOTE(review): semantics live in PQDDataset - confirm there

# --- Optimisation (forwarded to train_fft_model) ---
epochs = 100
batch_size = 32
lr = 0.0005

In [5]:
# Read the training table and wrap it in a PQDDataset built with the
# windowing settings from the configuration cell.
train_df = pd.read_csv('train.csv')
train_dataset = PQDDataset(
    df=train_df,
    window_size=window_size,
    stride=stride,
    target_mode=target_mode,
)

Creating window slices:   0%|          | 0/66 [00:00<?, ?it/s]

In [6]:
# "Balanced" class weights computed from the training targets, one weight per
# distinct label in np.unique order.
train_targets = train_dataset.df.target
weights = compute_class_weight(
    "balanced",
    classes=np.unique(train_targets),
    y=train_targets,
)
# torch.tensor with an explicit dtype replaces the legacy torch.FloatTensor
# constructor; the result is still a float32 CPU tensor.
weights = torch.tensor(weights, dtype=torch.float32)

In [7]:
# Read the held-out table and wrap it in a PQDDataset using the same
# windowing settings as the training set.
test_df = pd.read_csv('test.csv')
test_dataset = PQDDataset(
    df=test_df,
    window_size=window_size,
    stride=stride,
    target_mode=target_mode,
)

Creating window slices:   0%|          | 0/13 [00:00<?, ?it/s]

In [8]:
# Take the frame length from the configuration cell instead of repeating the
# literal 32, so the model and the dataset windows cannot drift apart.
# NOTE(review): assumes CONV_MLP's `frame_size` is the dataset window length - confirm.
model = CONV_MLP(frame_size=window_size)

In [9]:
# Train on the training windows and evaluate against the test windows; the
# recorded run logs a 3-element array and its mean after each epoch.
# NOTE(review): the class weights computed earlier are not used here
# (`weights=None`) - confirm this is intentional.
train_fft_model(
    model=model,
    dataset=train_dataset,
    run_seed=seed,
    eval_dataset=test_dataset,
    epochs=epochs,
    batch_size=batch_size,
    lr=lr,
    weights=None,
)

Epoch ...:   0%|          | 0/100 [00:00<?, ?it/s]

Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.18482165326677727
6.229449271513206


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.97823031 0.93154287 0.30923695]
epoch 0  0.7396700409425007


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.088654726255447
1.5530527498443023


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98024051 0.9301885  0.31698113]
epoch 1  0.7424700486393906


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.07140452637216071
1.4083541881071224


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98223037 0.9405315  0.26598465]
epoch 2  0.7295821734995455


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.06147701931210569
1.3299103508066592


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98162707 0.94855159 0.36391437]
epoch 3  0.7646976757202384


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.05086969752678858
1.2775014085385343


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.9842587  0.94283806 0.47968545]
epoch 4  0.8022607379980647


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.046789950305358735
1.236559078403322


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98774141 0.9622554  0.5       ]
epoch 5  0.8166656034181043


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.043172537109659
1.2036178429023932


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98588175 0.95693424 0.51313131]
epoch 6  0.8186491002528706


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.04158933812004154
1.1493122948207404


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98479853 0.9369195  0.46831364]
epoch 7  0.7966772242819288


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.03799514135416619
1.1055139632746425


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98302772 0.95594549 0.55701754]
epoch 8  0.8319969180965131


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.037420785305250025
1.0849909215431193


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98916222 0.96346439 0.54109589]
epoch 9  0.831240834902287


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.0342298370463246
1.04683626825621


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98820097 0.95988935 0.45726496]
epoch 10  0.8017850922871889


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.03499547591713462
1.0500466697055701


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98741948 0.96962445 0.4896    ]
epoch 11  0.8155479760227585


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.03247002873380083
1.0179406870500753


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98669405 0.9533503  0.53144654]
epoch 12  0.8238302954769695


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.03187441017329703
1.0016044104123751


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.99004611 0.9691066  0.55205047]
epoch 13  0.8370677253744819


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.030849636728834087
0.9815362797288925


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98388552 0.95918367 0.53373313]
epoch 14  0.8256007765521044


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.031092716827742478
0.9632775234307107


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98747856 0.95888358 0.54574639]
epoch 15  0.8307028430500275


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.028533542644123378
0.9528191916078512


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98793618 0.96469509 0.49319728]
epoch 16  0.8152761837992722


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.026345628248875076
0.9355557783214219


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.9900845  0.96662616 0.54195323]
epoch 17  0.83288796221857


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.02675367629500097
0.9253643707695438


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98394535 0.88207212 0.31734317]
epoch 18  0.727786879038761


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.025821397759526188
0.9262972120080886


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98558527 0.95454545 0.48897059]
epoch 19  0.8097004384816228


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.025161352243441244
0.9134770344935159


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.96539501 0.93309672 0.48911652]
epoch 20  0.7958694150295927


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.02525726828424623
0.9074636648659589


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98865492 0.95879805 0.40718563]
epoch 21  0.7848795317146197


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.024191323361416227
0.8973861644369621


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98744626 0.93551194 0.4287119 ]
epoch 22  0.7838900326181722


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.025464334163400215
0.8789351589258692


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98568437 0.95124717 0.5480315 ]
epoch 23  0.8283210113528625


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.023513530915055897
0.8818411959117198


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98228372 0.93363003 0.42580645]
epoch 24  0.7805734005661601


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.02387184485874528
0.8686525320033762


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98370029 0.94515088 0.52012384]
epoch 25  0.8163250038938644


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.023002387683958517
0.8678086452507385


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98638699 0.95238095 0.58249641]
epoch 26  0.8404214506272494


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.02245663369810819
0.8463761392957866


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98749409 0.95407878 0.47095761]
epoch 27  0.8041768266997519


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.02257519144678187
0.8544676285963774


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.99079317 0.96177696 0.5       ]
epoch 28  0.8175233752114098


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.023212214714208344
0.8526869966233536


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.99097777 0.94826617 0.43057325]
epoch 29  0.7899390611384242


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.02228988988224287
0.8423529009830046


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98786575 0.9604082  0.5846926 ]
epoch 30  0.8443221825939665


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.02292518113794966
0.8415574837759046


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.99025778 0.9627018  0.52173913]
epoch 31  0.8248995704136154


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.020192327042325493
0.836147179685579


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98875436 0.91969668 0.39466667]
epoch 32  0.7677059014302673


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.023240674202989482
0.8351727423298285


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98958738 0.96233609 0.52631579]
epoch 33  0.826079755203739


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.02117981333160746
0.821094731423522


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98721688 0.95764307 0.55847953]
epoch 34  0.8344464961537118


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.02043966197079284
0.8203621958621473


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98527176 0.95730918 0.51472471]
epoch 35  0.8191018841867344


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.0205261270349288
0.8191676253655608


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98929188 0.9676525  0.55625   ]
epoch 36  0.8377314593116864


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.019898651162405138
0.8133721636489676


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98674307 0.95355684 0.61386139]
epoch 37  0.8513871004432526


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.020617173224939493
0.8055611628526534


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98885014 0.9601032  0.50557621]
epoch 38  0.8181765165557878


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.02016914903186303
0.8078203572599061


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98807833 0.95595127 0.49230769]
epoch 39  0.8121124292142508


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.019245240251597684
0.8192848309690948


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98701075 0.94823901 0.46355685]
epoch 40  0.7996022026067154


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.020750895366157374
0.7927448482575847


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98807583 0.95960072 0.59097525]
epoch 41  0.846217268230555


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.020549601834708102
0.8005660801391582


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98783371 0.95987338 0.54045307]
epoch 42  0.8293867232940907


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.01882445713236263
0.8057364489294175


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98848995 0.9593918  0.52727273]
epoch 43  0.825051493881242


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.018169465158374355
0.7832809044258551


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.97939362 0.94121864 0.55721393]
epoch 44  0.8259420631837409


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.021643957827833474
0.7874100261675748


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.99139859 0.96299024 0.47035573]
epoch 45  0.8082481872672779


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.018990859931310673
0.7702703533782117


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.99103989 0.96307265 0.51580699]
epoch 46  0.8233065098748912


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.01899875932003532
0.7710869664773804


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.99135845 0.96335664 0.56458056]
epoch 47  0.8397652180151559


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.018424728560518775
0.782769501209259


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98558629 0.9569952  0.59324523]
epoch 48  0.8452755720942798


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.01862203159836766
0.7681780477844224


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.95366674 0.90270224 0.48247978]
epoch 49  0.7796162539413299


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.01983743801759313
0.7701779577490975


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.989027   0.96113602 0.58361775]
epoch 50  0.8445935917380867


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.016625303027291465
0.7561169203210171


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.9882617  0.95908408 0.53535354]
epoch 51  0.8275664381956457


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.017324783590298112
0.7512367935946834


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98956267 0.96548507 0.59383754]
epoch 52  0.8496284255891954


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.018804641269226995
0.7513650017780934


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98822518 0.95940752 0.56222548]
epoch 53  0.836619392331837


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.017232269258250187
0.7493849479625847


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98861382 0.96399254 0.54901961]
epoch 54  0.8338753212800861


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.019774744492419408
0.7565537287989437


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98812146 0.94572759 0.48734177]
epoch 55  0.8070636093405943


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.01779295022468366
0.74517422455537


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.99023167 0.9670025  0.56240602]
epoch 56  0.839880061409822


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.017257072081285903
0.761936207892836


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.99043598 0.96857568 0.62883436]
epoch 57  0.8626153386360403


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.01752135572274177
0.7474369903456504


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.9889579  0.96854876 0.61026352]
epoch 58  0.8559233961006782


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.018804328890444845
0.7291459704938611


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98656446 0.9173806  0.38133548]
epoch 59  0.7617601813672157


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.017294015731045935
0.7418709812682023


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.9890928  0.96772409 0.5922619 ]
epoch 60  0.8496929298661954


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.01704509979792324
0.7401127595400908


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98839023 0.96378363 0.51069182]
epoch 61  0.8209552279017628


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.0156945037314667
0.7242471886764317


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98892419 0.96102429 0.61079545]
epoch 62  0.853581310080772


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.017125738334549456
0.7241058032615474


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98754042 0.96509892 0.60648801]
epoch 63  0.8530424489945355


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.017973229594598048
0.7360144417266337


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98989986 0.96329504 0.60990712]
epoch 64  0.854367339846294


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.01760682454197697
0.7303220009320326


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98751625 0.96321564 0.52347084]
epoch 65  0.8247342432654685


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.016685229572331973
0.7124610017579683


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

[0.98130679 0.96305213 0.62447257]
epoch 66  0.8562771672974675


Step ...:   0%|          | 0/1948 [00:00<?, ?it/s]

0.01608342463578564
0.7145235084257332


Step ...:   0%|          | 0/541 [00:00<?, ?it/s]

KeyboardInterrupt: 

### evaluate_model(model, test_dataset)

In [None]:
# Scratch cell: draws an integer from the global `random` generator, with both
# 0 and 256 included. NOTE(review): presumably used to pick a run seed - confirm or remove.
random.randint(0, 256)

In [None]:
# Scratch cell: another draw from the global `random` generator (0..256 inclusive).
# NOTE(review): unexecuted leftover - consider deleting.
random.randint(0, 256)

In [4]:
# Scratch cell: draw from the global `random` generator (0..256 inclusive).
# NOTE(review): its execution count is out of sequence with the cells above,
# so the value shown will not reproduce under Restart & Run All.
random.randint(0, 256)

108

In [5]:
# Scratch cell: draw from the global `random` generator (0..256 inclusive).
# NOTE(review): executed out of order relative to the training cell - confirm whether it is still needed.
random.randint(0, 256)

182

In [6]:
# Scratch cell: draw from the global `random` generator (0..256 inclusive).
# NOTE(review): duplicate of the preceding scratch cells - consider removing.
random.randint(0, 256)

171