In [1]:
from rljax.algorithm import DQN
from rljax.trainer import Trainer

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from micro_price_trading.config import TWENTY_SECOND_DAY
from micro_price_trading import Preprocess, OptimalExecutionEnvironment

In [2]:
# Preprocess the TBT/TBF pair data; `data` is what the execution environment consumes below.
DATA_FILE = 'TBT_TBF_data.csv'
preprocessor = Preprocess(DATA_FILE, res_bin=6)
data = preprocessor.process()

In [None]:
# A trading day runs 23,400 seconds (9:30am-4pm). `TWENTY_SECOND_DAY` splits it into
# 20-second increments -- presumably 23,400 / 20 = 1,170 steps per episode (the training
# log counts down from 1170); confirm in `micro_price_trading.config`.

NUM_AGENT_STEPS = 5000
# Same value as before (2500); tied to the run length so the two stay consistent.
# Evaluations presumably fire every EVAL_INTERVAL agent steps.
EVAL_INTERVAL = NUM_AGENT_STEPS // 2
SEED = 0

# Training environment; `env_test` is a copy handed to the trainer (presumably for evaluation).
env = OptimalExecutionEnvironment(
    data,
    risk_weights=(2, 1),
    trade_penalty=100,
    max_purchase=4,
    steps=TWENTY_SECOND_DAY,
    end_units_risk=TWENTY_SECOND_DAY,  # Ideally, this should be `TWENTY_SECOND_DAY//5*2`
    seed=SEED,
)
env_test = env.copy_env()

algo = DQN(
    num_agent_steps=NUM_AGENT_STEPS,
    state_space=env.observation_space,
    action_space=env.action_space,
    seed=SEED,
    batch_size=256,
    start_steps=1000,
    update_interval=1,
    update_interval_target=400,
    eps_decay_steps=0,  # NOTE(review): presumably disables the epsilon decay schedule -- confirm against rljax's DQN
    loss_type="l2",
    lr=1e-5,  # Have been messing around with this but doesn't seem to make a big difference
)

trainer = Trainer(
    env=env,
    env_test=env_test,
    algo=algo,
    log_dir="",  # NOTE(review): empty log_dir -- output paths presumably resolve relative to the working directory
    num_agent_steps=NUM_AGENT_STEPS,
    eval_interval=EVAL_INTERVAL,
    seed=SEED,
)
trainer.train()



5 1170 1165
10 1164 1160
15 1156 1155
20 1155 1150
25 1149 1145
30 1141 1140
35 1140 1135
40 1134 1130
45 1126 1125
50 1125 1120
55 1119 1115
60 1111 1110
65 1110 1105
70 1104 1100
75 1096 1095
80 1095 1090
85 1089 1085
90 1081 1080
95 1080 1075
100 1074 1070
105 1066 1065
110 1065 1060
115 1059 1055
120 1051 1050
125 1050 1045
130 1044 1040
140 1035 1030
145 1029 1025
150 1023 1020
155 1019 1015
160 1013 1010
165 1009 1005
170 1003 1000
175 999 995
180 993 990
185 989 985
190 983 980
195 979 975
200 973 970
205 969 965
210 963 960
215 959 955
220 953 950
225 949 945
230 943 940
235 939 935
240 933 930
245 929 925
250 923 920
255 919 915
270 905 900
275 899 895
280 893 890
285 889 885
290 883 880
295 879 875
300 873 870
305 869 865
310 863 860
315 859 855
320 853 850
325 849 845
330 843 840
335 839 835
340 833 830
345 829 825
350 823 820
355 819 815
360 813 810
365 809 805
370 803 800
375 799 795
380 793 790
385 789 785
390 783 780
395 779 775
400 773 770
405 769 765
410 763 760
415 75

975 198 195
980 194 190
985 188 185
990 184 180
995 178 175
1000 174 170
1005 168 165
1010 164 160
1015 158 155
1020 154 150
1025 148 145
1030 144 140
1035 138 135
1040 134 130
1045 128 125
1050 124 120
1055 118 115
1060 114 110
1065 108 105
1070 104 100
1075 98 95
1080 94 90
1085 88 85
1090 84 80
1095 78 75
1100 74 70
1105 68 65
1110 64 60
1115 58 55
1120 54 50
1125 48 45
1130 44 40
1135 38 35
1140 34 30
1145 28 25
1150 24 20
1155 18 15
1160 14 10
1165 8 5
5 1170 1165
10 1164 1160
15 1158 1155
20 1154 1150
25 1148 1145
30 1144 1140
35 1138 1135
40 1134 1130
45 1128 1125
50 1124 1120
55 1118 1115
60 1114 1110
65 1108 1105
70 1104 1100
75 1098 1095
80 1094 1090
85 1088 1085
90 1084 1080
95 1078 1075
100 1074 1070
105 1068 1065
110 1064 1060
115 1058 1055
120 1054 1050
125 1048 1045
130 1044 1040
135 1038 1035
140 1034 1030
145 1028 1025
150 1024 1020
155 1018 1015
160 1014 1010
165 1008 1005
170 1004 1000
175 998 995
180 994 990
185 988 985
190 984 980
195 978 975
200 974 970
205 968 96

890 283 280
895 279 275
900 273 270
905 269 265
910 263 260
915 259 255
920 253 250
925 249 245
930 243 240
935 239 235
940 233 230
945 229 225
950 223 220
955 219 215
960 213 210
965 209 205
970 203 200
975 199 195
980 193 190
985 189 185
990 183 180
995 179 175
1000 173 170
1005 169 165
1010 163 160
1015 159 155
1020 153 150
1025 149 145
1030 143 140
1035 139 135
1040 133 130
1045 129 125
1050 123 120
1055 119 115
1060 113 110
1065 109 105
1070 103 100
1075 99 95
1080 93 90
1085 89 85
1090 83 80
1095 79 75
1100 73 70
1105 69 65
1110 63 60
1115 59 55
1120 53 50
1125 49 45
1130 43 40
1135 39 35
1140 33 30
1145 29 25
1150 23 20
1155 19 15
1160 13 10
1165 9 5
5 1170 1165
10 1164 1160
15 1158 1155
20 1154 1150
25 1148 1145
30 1144 1140
35 1138 1135
40 1134 1130
45 1128 1125
50 1124 1120
55 1118 1115
60 1114 1110
65 1108 1105
70 1104 1100
75 1098 1095
80 1094 1090
85 1088 1085
90 1084 1080
95 1078 1075
100 1074 1070
105 1068 1065
110 1064 1060
115 1058 1055
120 1054 1050
125 1048 1045
130 

760 414 410
765 408 405
770 404 400
775 398 395
780 394 390
785 388 385
790 384 380
795 378 375
800 374 370
805 368 365
810 364 360
815 358 355
820 354 350
825 348 345
830 344 340
835 338 335
840 334 330
845 328 325
850 324 320
855 318 315
860 314 310
865 308 305
870 304 300
875 298 295
880 294 290
885 288 285
890 284 280
895 278 275
900 274 270
905 268 265
910 264 260
915 258 255
920 254 250
925 248 245
930 244 240
935 238 235
940 234 230
945 228 225
950 224 220
955 218 215
960 214 210
965 208 205
970 204 200
975 198 195
980 194 190
985 188 185
990 184 180
995 178 175
1000 174 170
1005 168 165
1010 164 160
1015 158 155
1020 154 150
1025 148 145
1030 144 140
1035 138 135
1040 134 130
1045 128 125
1050 124 120
1055 118 115
1060 114 110
1065 108 105
1070 104 100
1075 98 95
1080 94 90
1085 88 85
1090 84 80
1095 78 75
1100 74 70
1105 68 65
1110 64 60
1115 58 55
1120 54 50
1125 48 45
1130 44 40
1135 38 35
1140 34 30
1145 28 25
1150 24 20
1155 18 15
1160 14 10
1165 8 5
5 1170 1165
10 1164 11

830 344 340
835 338 335
840 334 330
845 328 325
850 324 320
855 318 315
860 314 310
865 308 305
870 304 300
875 298 295
880 294 290
885 288 285
890 284 280
895 278 275
900 274 270
905 268 265
910 264 260
915 258 255
920 254 250
925 248 245
930 244 240
935 238 235
940 234 230
945 228 225
950 224 220
955 218 215
960 214 210
965 208 205
970 204 200
975 198 195
980 194 190
985 188 185
990 184 180
995 178 175
1000 174 170
1005 168 165
1010 164 160
1015 158 155
1020 154 150
1025 148 145
1030 144 140
1035 138 135
1040 134 130
1045 128 125
1050 124 120
1055 118 115
1060 114 110
1065 108 105
1070 104 100
1075 98 95
1080 94 90
1085 88 85
1090 84 80
1095 78 75
1100 74 70
1105 68 65
1110 64 60
1115 58 55
1120 54 50
1125 48 45
1130 44 40
1135 38 35
1140 34 30
1145 28 25
1150 24 20
1155 18 15
1160 14 10
1165 8 5
5 1170 1165
10 1164 1160
15 1158 1155
20 1154 1150
25 1148 1145
30 1144 1140
35 1138 1135
40 1134 1130
45 1128 1125
50 1124 1120
55 1118 1115
60 1114 1110
65 1108 1105
70 1104 1100
75 1098 1

400 774 770
405 768 765
410 764 760
415 758 755
420 754 750
425 748 745
430 744 740
435 738 735
440 734 730
445 728 725
450 724 720
455 718 715
460 714 710
465 708 705
470 704 700
475 698 695
480 694 690
485 688 685
490 684 680
495 678 675
500 674 670
505 668 665
510 664 660
515 658 655
520 654 650
525 648 645
530 644 640
535 638 635
540 634 630
545 628 625
550 624 620
555 618 615
560 614 610
565 608 605
570 604 600
575 598 595
580 594 590
585 588 585
590 584 580
595 578 575
600 574 570
605 568 565
610 564 560
615 558 555
620 554 550
625 548 545
630 544 540
635 538 535
640 534 530
645 528 525
650 524 520
655 518 515
660 514 510
665 508 505
670 504 500
675 498 495
680 494 490
685 488 485
690 484 480
695 478 475
700 474 470
705 468 465
710 464 460
715 458 455
720 454 450
725 448 445
730 444 440
735 438 435
740 434 430
745 428 425
750 424 420
755 418 415
760 414 410
765 408 405
770 404 400
775 398 395
780 394 390
785 388 385
790 384 380
795 378 375
800 374 370
805 368 365
810 364 360
815 

275 898 895
280 894 890
285 888 885
290 884 880
295 878 875
300 874 870
305 868 865
310 864 860
315 858 855
320 854 850
325 848 845
330 844 840
335 838 835
340 834 830
345 828 825
350 824 820
355 818 815
360 814 810
365 808 805
370 804 800
375 798 795
380 794 790
385 788 785
390 784 780
395 778 775
400 774 770
405 768 765
410 764 760
415 758 755
420 754 750
425 748 745
430 744 740
435 738 735
440 734 730
445 728 725
450 724 720
455 718 715
460 714 710
465 708 705
470 704 700
475 698 695
480 694 690
485 688 685
490 684 680
495 678 675
500 674 670
505 668 665
510 664 660
515 658 655
520 654 650
525 648 645
530 644 640
535 638 635
540 634 630
545 628 625
550 624 620
555 618 615
560 614 610
565 608 605
570 604 600
575 598 595
580 594 590
585 588 585
590 584 580
595 578 575
600 574 570
605 568 565
610 564 560
615 558 555
620 554 550
625 548 545
630 544 540
635 538 535
640 534 530
645 528 525
650 524 520
655 518 515
660 514 510
665 508 505
670 504 500
675 498 495
680 494 490
685 488 485
690 

180 994 990
185 988 985
190 984 980
195 978 975
200 974 970
205 968 965
210 964 960
215 958 955
220 954 950
225 948 945
230 944 940
235 938 935
240 934 930
245 928 925
250 924 920
255 918 915
260 914 910
265 908 905
270 904 900
275 898 895
280 894 890
285 888 885
290 884 880
295 878 875
300 874 870
305 868 865
310 864 860
315 858 855
320 854 850
325 848 845
330 844 840
335 838 835
340 834 830
345 828 825
350 824 820
355 818 815
360 814 810
365 808 805
370 804 800
375 798 795
380 794 790
385 788 785
390 784 780
395 778 775
400 774 770
405 768 765
410 764 760
415 758 755
420 754 750
425 748 745
430 744 740
435 738 735
440 734 730
445 728 725
450 724 720
455 718 715
460 714 710
465 708 705
470 704 700
475 698 695
480 694 690
485 688 685
490 684 680
495 678 675
500 674 670
505 668 665
510 664 660
515 658 655
520 654 650
525 648 645
530 644 640
535 638 635
540 634 630
545 628 625
550 624 620
555 618 615
560 614 610
565 608 605
570 604 600
575 598 595
580 594 590
585 588 585
590 584 580
595 

In [None]:
# Raise pandas' display limits (rows and columns) in one call so the portfolio frame is not truncated.
# `n=1` is passed to `portfolios_to_df`; presumably it selects the most recent episode -- confirm.
pd.set_option('display.max_rows', 2000, 'display.max_columns', 2000)
df = env_test.portfolios_to_df(n=1)
df

In [None]:
# `penalty_trade` column value at index 6 of the portfolio frame.
# NOTE(review): Series[6] is a label lookup on an integer index but falls back to a positional
# lookup on a non-integer index (deprecated in pandas 2.x) -- prefer explicit .loc / .iloc.
df['penalty_trade'][6]

In [None]:
# NOTE(review): identical to the "Raw actions" cell that follows -- one of the two can be removed.
env_test._raw_actions[-1]

#### Raw actions chosen by the DQN, before they are zero-centered

In [None]:
# Last entry of the recorded raw actions (presumably the most recent episode).
env_test._raw_actions[-1]

#### Rewards, each with a flag indicating the reward type

In [None]:
# Last entry of the recorded rewards (presumably the most recent episode).
env_test._rewards[-1]

#### Raw observations seen by the DQN

In [None]:
# Last entry of the recorded observations (presumably the most recent episode).
env_test._observations[-1]

### Raw format for accessing the portfolio history
Ideally, this should be accessed through `env_test.portfolio_history` instead.

## FIXME: portfolio history length is wrong for the training env
For some reason, the length of this is nowhere near what it should be for the base training env (`env`): I get only around 17 entries for it, whereas `env_test` seems to be correct.

In [None]:
# Last portfolio snapshot of the entry at outer index 3 (presumably the 4th recorded episode).
# NOTE(review): hard-coded 3 -- use [-1] for the latest episode, or name the index.
env_test._portfolios[3][-1]

In [None]:
# End-of-episode risk target (the `end_units_risk` value the environment was constructed with).
env_test.end_units_risk

In [None]:
# Remaining risk recorded for period 5; fall back to the full end-of-day target when period 5 has no entry.
env_test._period_risk[5] if 5 in env_test._period_risk else env_test.end_units_risk

In [None]:
# Risk still to be acquired: the end-of-day target minus the current portfolio's total risk.
env_test.end_units_risk - env_test.current_portfolio.total_risk

In [None]:
# Prices the test environment recorded at the start (presumably of the current episode).
env_test.prices_at_start

In [None]:
# Debugging aid, intentionally disabled: uncomment to advance the test env by one step with
# action 2 and inspect the resulting last portfolio snapshot.
# print(env_test.step(2))
# env_test._portfolios[-1][-1]

### Possible plot bug: when the shares track this line perfectly, the risk plot does not track its line perfectly, and vice versa

In [None]:
# Default plot of the test environment.
# NOTE(review): may not agree with the risk plot -- see the note above this cell.
env_test.plot()

### Case in point: shares track well, but risk doesn't

In [None]:
# Same plot, but for the `risk_history` series.
env_test.plot('risk_history')

### Count the number of chosen trades
Use `len([p.time for p in env_test.portfolio_history[-1] if (p.trade or p.penalty_trade)])` to count all trades, or `len([p.time for p in env_test.portfolio_history[-1] if (p.penalty_trade)])` to count only the penalty trades.

In [None]:
# Number of portfolio snapshots in the latest episode flagged as a chosen trade
# (penalty trades are tracked separately via `penalty_trade`).
sum(1 for portfolio in env_test.portfolio_history[-1] if portfolio.trade)

### Convert `env_test._period_risk` to show the amount of risk we should have bought, instead of the amount of risk remaining (this builds a new dict; `_period_risk` itself is not modified)

In [None]:
# Build a new dict mapping each period to (end-of-day target - remaining risk), i.e. risk
# that should already have been bought; `_period_risk` itself is left untouched.
dict(zip(env_test._period_risk, env_test.end_units_risk - np.asarray([*env_test._period_risk.values()])))

In [None]:
# Pair each reward of the latest episode with the portfolio snapshot at the same position
# (zip stops at the shorter of the two sequences).
[(reward, portfolio) for reward, portfolio in zip(env_test._rewards[-1], env_test.portfolio_history[-1])]

In [None]:
# NOTE(review): duplicate of the "Raw observations seen by DQN" cell earlier in the notebook.
env_test._observations[-1]

In [None]:
# Cumulative reward over the latest recorded episode.
# NOTE(review): the "Rewards" heading describes `_rewards` as rewards "along with a flag";
# if each entry is a (reward, flag) pair, np.cumsum (no axis) flattens and sums both -- confirm the layout.
fig, ax = plt.subplots()
ax.plot(np.cumsum(env_test._rewards[-1]))
ax.set(title='Cumulative reward (latest episode)', xlabel='Reward index', ylabel='Cumulative reward');

In [None]:
import numpy as np

In [None]:
# Risk trajectory of the latest recorded episode.
env_test.risk_history[-1]

In [None]:
# Distinct step-to-step changes in the latest episode's risk trajectory.
# np.diff returns the actual increments between consecutive steps; np.gradient would use central
# differences in the interior, averaging neighbouring steps and yielding values (e.g. 2.5 between
# steps of 0 and 5) that never occur as a real step. np.unique also returns them sorted.
np.unique(np.diff(env_test.risk_history[-1]))

In [None]:
# Full portfolio history for every recorded episode -- potentially large output; index with [-1] for the latest.
env_test.portfolio_history

In [None]:
# Indices in the latest risk trajectory where risk exceeds the end-of-day target.
# np.asarray makes the comparison element-wise even when risk_history entries are plain
# Python lists (`list > int` raises TypeError); it is a no-op for ndarrays.
np.argwhere(np.asarray(env_test.risk_history[-1]) > env_test.end_units_risk)