In [2]:
import os
import random
import shutil
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch

# Add parent folders to the import path (relative to the kernel's working
# directory) so the `src` package imported below can be found.
sys.path.append("../")
sys.path.append("../../src")
sys.path.append("../../")

from src.train import Trainer
from src.nn.net import Net, get_new_net, load_net
from src.data.plot_light_curve import plot_curves
from src.config import Config, DataConfig, FilterConfig, AugmentationConfig, PACKAGE_PATH

In [3]:
# Create the per-run output folders (models, cached datasets, saved configs).
# `os` is imported in the import cell at the top of the notebook.
# NOTE(review): the net's save_path in the config cell points at
# output/models/Net_13_3_2023_v1/, not at models/FOLDER_NAME — confirm intended.
FOLDER_NAME = "Net_14_3_2023"
for subdir in ("models", "datasets", "configurations"):
    os.makedirs(f"{PACKAGE_PATH}/output/{subdir}/{FOLDER_NAME}", exist_ok=True)

In [53]:
cfg = Config()

# --- Network -----------------------------------------------------------------
# The net name is used both as its name and as its model folder; define it once
# so the two cannot drift apart.
NET_NAME = "Net_13_3_2023_v1"
cfg.net_config.device = "cpu"
cfg.net_config.name = NET_NAME
cfg.net_config.save_path = f"{PACKAGE_PATH}/output/models/{NET_NAME}/"

# --- Data --------------------------------------------------------------------
cfg.data_config.path = f"{PACKAGE_PATH}/resources/mmt_13_3_2023/"
cfg.data_config.labels = ["cz_3", "falcon_9", "atlas", "h2a", "globalstar"]
cfg.data_config.convert_to_mag = False
# Presumably a per-class cap on training examples; it is far above the largest
# raw class here (globalstar, 42174 examples), so likely no class is capped —
# confirm in Trainer.load_data.
cfg.data_config.number_of_training_examples_per_class = 1_000_000
cfg.data_config.augmentation = None
cfg.data_config.filter = FilterConfig(
    n_bins=30,
    n_gaps=10,
    gap_size=5,
    rms_ratio=0.0,
    non_zero_ratio=0.8,
)
cfg.data_config.validation_split = 0.2

# --- Run switches ------------------------------------------------------------
SAMPLER = True         # call trainer.add_sampler() after the data is loaded
LOAD = False           # True: load_net(cfg, seed=SEED, ...); False: build a new net
SEED = None            # replaced by cfg.seed when a new net is built (LOAD is False)
CHECKPOINT = "latest"  # checkpoint passed to load_net when LOAD is True

# Folder name of the cached, prepared dataset, derived from the filter settings.
# NOTE(review): the key ignores rms_ratio, labels, validation_split and the other
# data_config fields, and int(non_zero_ratio * 10) truncates (0.85 and 0.8 give
# the same key). Changing only those settings silently reuses a stale cached
# dataset — delete the cache folder or extend the key when you change them.
filt = cfg.data_config.filter
DATA_SEED = f"{filt.n_bins}_{filt.n_gaps}_{filt.gap_size}_{int(filt.non_zero_ratio * 10)}"

In [54]:

# Build the trainer and load the prepared dataset: read it from the on-disk
# cache when present, otherwise build it from cfg.data_config and cache it.
trainer = Trainer(None)
net = None  # reset so a failure below cannot leave a net from an earlier run in scope

dataset_path = f"{PACKAGE_PATH}/output/datasets/{FOLDER_NAME}"
data_dir = f"{dataset_path}/{DATA_SEED}"

if os.path.exists(data_dir):
    # The folder's existence is the cache marker; delete it to force a rebuild.
    trainer.load_data_from_file(data_dir)
else:
    trainer.load_data(cfg.data_config)
    os.makedirs(data_dir, exist_ok=True)
    try:
        trainer.save_data(data_dir)
    except BaseException:
        # Because existence alone marks a valid cache, a save that fails or is
        # interrupted must not leave a partial folder that a later run would trust.
        shutil.rmtree(data_dir, ignore_errors=True)
        raise

# Either restore a saved net or build a new one and record its configuration.
if LOAD:
    # The dataset loaded above is reused; it is not read from disk a second time.
    # NOTE(review): SEED is None unless set by hand in the config cell — confirm
    # how load_net handles seed=None.
    net = load_net(cfg, seed=SEED, checkpoint=CHECKPOINT)
else:
    net = get_new_net(cfg)
    config_file = f"{PACKAGE_PATH}/output/configurations/{FOLDER_NAME}/{net.name}.json"
    with open(config_file, "w") as f:
        print(cfg.to_json(), file=f)
    SEED = cfg.seed  # remember the seed of the freshly built net
trainer.net = net

if SAMPLER:
    trainer.add_sampler()

Folder /home/bach/Desktop/work/classification_of_light_curves/resources/mmt_13_3_2023/: 100%|██████████| 5/5 [00:00<00:00, 140.13it/s]


Label: falcon_9 5660 examples.
Label: h2a 5863 examples.
Label: globalstar 42174 examples.
Label: atlas 16704 examples.
Label: cz_3 26624 examples.
-------------- Filtered ---------------
Label: falcon_9 2205, 5660 examples.
Label: h2a 2411, 5863 examples.
Label: globalstar 4420, 42174 examples.
Label: atlas 2759, 16704 examples.
Label: cz_3 10169, 26624 examples.
Training set: 17570
Validation set: 4394
SEED: 274534
middle_dim 370


In [47]:
# Show the seed of the current net (set from cfg.seed when a new net is built).
# NOTE(review): the saved output below (907099) is from execution 47, i.e. before
# the data/net cell ran as In [54] and printed "SEED: 274534" — re-run to refresh.
SEED

907099

In [55]:
# Training hyper-parameters. The first two positional arguments of Trainer.train
# are presumably the number of epochs (the progress bar below counts 200 steps)
# and the batch size — TODO confirm against its signature.
N_EPOCHS = 200
BATCH_SIZE = 128
SAVE_INTERVAL = 50  # passed as save_interval; presumably epochs between checkpoints

trainer.train(N_EPOCHS, BATCH_SIZE, tensorboard_on=True, save_interval=SAVE_INTERVAL, print_on=False)

Training: 100%|██████████| 200/200 [01:02<00:00,  3.18it/s]


In [57]:
# Evaluate the trained net: prints train/validation loss and accuracy, a
# validation confusion matrix (it sums to the 4394 validation examples) and
# per-class "Precision"/"Recall" rows (see output below).
# NOTE(review): the matrix rows sum to the per-class validation counts (e.g.
# falcon_9: 140+266+24+9+2 = 441 = 20% of its 2205 filtered examples), so rows
# are true labels. The row printed as "Precision" equals diagonal / row sum
# (cz_3: 1569/2034 = 77.14%), which is recall, and "Recall" equals diagonal /
# column sum (cz_3: 1569/1891 = 82.97%), which is precision. The two rows appear
# swapped inside Trainer.evaluate — verify there.
trainer.evaluate(cfg.data_config.labels)

0 cz_3
1 falcon_9
2 atlas
3 h2a
4 globalstar
Train:
	Loss: 0.00022455869849019768
	Acc: 99.6812749003984
Validation:
	Loss: 0.0425152990609459
	Acc: 75.96722803823396
-----------------------------------------

        Label  cz_3  falcon_9  atlas  h2a  globalstar
0        cz_3  1569       163    199   62          41
1    falcon_9   140       266     24    9           2
2       atlas   123        23    369   17          20
3         h2a    28        14     31  354          56
4  globalstar    31         7     26   40         780

-----------------------------------------

                cz_3   falcon_9      atlas        h2a  globalstar
Precision  77.138643  60.317460  66.847826  73.291925   88.235294
Recall     82.971973  56.236786  56.856703  73.443983   86.763070

-----------------------------------------



### V1
- no sampler
- n_bins = 30
- n_gaps = 10
- gap_size = 10
- non_zero_ratio = 0.6

```
                cz_3   falcon_9      atlas        h2a  globalstar
Precision  36.662356  70.914327  56.074766  67.330677   75.521850
Recall     88.537840  17.106634  24.747371  33.410214   87.384883

Training set: 10000
Validation set: 24841
```


### V2
- no sampler
- n_bins = 30
- n_gaps = 10
- gap_size = 5
- non_zero_ratio = 0.6

```
        Label  cz_3  falcon_9  atlas   h2a  globalstar
0        cz_3  5280      3335   2529  1042         256
1    falcon_9   175       826     83    45          14
2       atlas   238       305   1106   180          89
3         h2a    52        65    186   828         121
4  globalstar   248       148    483   614        4759

-----------------------------------------

                cz_3   falcon_9      atlas        h2a  globalstar
Precision  42.436907  72.265967  57.664234  66.134185   76.119642
Recall     88.102787  17.653345  25.210850  30.564784   90.837946

Training set: 10000
Validation set: 23007
```

### V3
- no sampler
- n_bins = 30
- n_gaps = 10
- gap_size = 5
- non_zero_ratio = 0.8

```
Training set: 9694
Validation set: 12276

        Label  cz_3  falcon_9  atlas  h2a  globalstar
0        cz_3  4081      1757   1512  588         234
1    falcon_9    72       293     49   20           8
2       atlas    88        92    494   42          43
3         h2a    19        23     48  338          55
4  globalstar    79        54    125  253        1909

-----------------------------------------

                cz_3   falcon_9      atlas        h2a  globalstar
Precision  49.938815  66.289593  65.085639  69.979296   78.884298
Recall     94.053929  13.204146  22.172352  27.236100   84.882170

-----------------------------------------
```

### V34 (no sampler)
- no sampler
- n_bins = 30
- n_gaps = 10
- gap_size = 5
- non_zero_ratio = 0.8

```
Training set: 17574
Validation set: 4396

        Label  cz_3  falcon_9  atlas  h2a  globalstar
0        cz_3  1600       118    120   80         117
1    falcon_9   143       242     24   13          20
2       atlas   150        37    288   38          39
3         h2a    59        18     35  288          83
4  globalstar    58        11     29   50         736

-----------------------------------------

                cz_3   falcon_9      atlas        h2a  globalstar
Precision  78.624079  54.751131  52.173913  59.627329   83.257919
Recall     79.601990  56.807512  58.064516  61.407249   73.969849

-----------------------------------------
```


### V34 (with sampler)
- sampler
- n_bins = 30
- n_gaps = 10
- gap_size = 5
- non_zero_ratio = 0.8

```
Training set: 17574
Validation set: 4396

-----------------------------------------

        Label  cz_3  falcon_9  atlas  h2a  globalstar
0        cz_3  1585        99    148   83         120
1    falcon_9   138       253     26   16           9
2       atlas   165        34    285   28          40
3         h2a    63        12     37  294          77
4  globalstar    67        17     31   48         721

-----------------------------------------

                cz_3   falcon_9      atlas        h2a  globalstar
Precision  77.886978  57.239819  51.630435  60.869565   81.561086
Recall     78.543112  60.963855  54.079696  62.686567   74.560496

-----------------------------------------
```


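### V5

- sampler
- n_bins = 30
- n_gaps = 10
- gap_size = 5
- non_zero_ratio = 0.8

Per-label example counts after filtering, out of the raw total:

```
Label: falcon_9 2205, 5660 examples.
Label: h2a 2411, 5863 examples.
Label: globalstar 4420, 42174 examples.
Label: atlas 2759, 16704 examples.
Label: cz_3 10169, 26624 examples.
Training set: 17570
Validation set: 4394

-----------------------------------------

        Label  cz_3  falcon_9  atlas  h2a  globalstar
0        cz_3  1569       163    199   62          41
1    falcon_9   140       266     24    9           2
2       atlas   123        23    369   17          20
3         h2a    28        14     31  354          56
4  globalstar    31         7     26   40         780

-----------------------------------------

                cz_3   falcon_9      atlas        h2a  globalstar
Precision  77.138643  60.317460  66.847826  73.291925   88.235294
Recall     82.971973  56.236786  56.856703  73.443983   86.763070

-----------------------------------------
```

**Caveat: the Precision and Recall rows are swapped.** The confusion-matrix rows are the true labels: each row sums to that class's validation count. For example, falcon_9 gives 140 + 266 + 24 + 9 + 2 = 441, which is 20% of its 2205 filtered examples. The row printed as "Precision" is therefore diagonal / row sum, which is *recall* (cz_3: 1569 / 2034 = 77.14%). The row printed as "Recall" is diagonal / column sum, which is *precision* (cz_3: 1569 / 1891 = 82.97%). The same pattern holds for the V2, V3 and both V34 tables above.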