#### Imports

In [None]:
from os import chdir
import matplotlib.pyplot as plt
%matplotlib inline
chdir('../')

from src.pulsar_analysis.train_neural_network_model import ImageMaskPair
from src.pulsar_analysis.preprocessing import PrepareFreqTimeImage, BinarizeToMask
from src.pulsar_analysis.postprocessing import DelayGraph,LineClassifier,ConnectedComponents,FitSegmentedTraces

from src.pulsar_analysis.train_neural_network_model import TrainImageToMaskNetworkModel,ImageToMaskDataset,InMaskToMaskDataset
from src.pulsar_analysis.neural_network_models import UNet, CustomLossUNet, UNetFilter, FilterCNN, CustomLossSemanticSeg, CNN1D, WeightedBCELoss

from src.pulsar_analysis.train_neural_network_model import TrainSignalToLabelModel,SignalToLabelDataset
from src.pulsar_analysis.neural_network_models import OneDconvEncoder,Simple1DCnnClassifier

from pulsar_simulation.generate_data_pipeline import generate_example_payloads_for_training
from src.pulsar_analysis.pipeline_methods import ImageDataSet, ImageReader,LabelDataSet,LabelReader,PipelineImageToCCtoLabels, PipelineImageToMask, PipelineImageToFilterToCCtoLabels

#### Preset loaders for images, masks, in-masks and labels

- Generate training and test data

In [None]:
#: Run this cell only if the training and test payloads still need to be generated.
import os

#: Worker count for payload generation: 10, but never more than the cores available on this machine.
payload_generation_num_cpus = min(10, os.cpu_count() or 1)

#: Both splits use identical settings; only the file-name tag differs.
for payload_tag in ('train_v0_', 'test_v0_'):
    generate_example_payloads_for_training(tag=payload_tag,
                                           num_payloads=100,
                                           plot_a_example=True,
                                           param_folder='./syn_data/runtime/',
                                           payload_folder='./syn_data/payloads/',
                                           num_cpus=payload_generation_num_cpus,
                                           )

In [None]:
#: Settings shared by image and mask preprocessing: rotational-phase averaging and a 128x128 resize.
shared_preprocessing_kwargs = dict(do_rot_phase_avg=True,
                                   do_resize=True,
                                   resize_size=(128, 128))

#: Input images stay continuous-valued (no binarization).
image_preprocessing_engine = PrepareFreqTimeImage(do_binarize=False, **shared_preprocessing_kwargs)

#: Masks are binarized with the "thresh" binarizer.
#: Alternatives: BinarizeToMask(binarize_func='gaussian_blur') or binarize_func='exponential'.
mask_binarizer = BinarizeToMask(binarize_func="thresh")
mask_preprocessing_engine = PrepareFreqTimeImage(do_binarize=True,
                                                 binarize_engine=mask_binarizer,
                                                 **shared_preprocessing_kwargs)

#: UNet-based engine that turns a normalized image into a mask.
#: NOTE(review): this is the same path later passed as `store_trained_model_at` to the UNet trainer;
#: if PipelineImageToMask reads the weights at construction, re-run this cell after (re)training the UNet.
cnn_model_to_make_mask_path: str = './syn_data/model/trained_UNet_test_v0.pt'
mask_maker_engine = PipelineImageToMask(
    image_to_mask_network=UNet(),
    trained_image_to_mask_network_path=cnn_model_to_make_mask_path,
)

signal_maker_engine = DelayGraph()
label_reader_engine = LabelReader()


- Load an image, its mask, the network-predicted mask (in-mask) and the label using the engines

In [None]:
idx = 0
payload_directory = './syn_data/payloads/'

#: Image payload (`_detected`) and mask payload (`_flux`) for training sample `idx`.
image_payload_file_path = f'{payload_directory}train_v0_{idx}_payload_detected.json'
image_preprocessing_engine.plot(payload_address=image_payload_file_path)

mask_payload_file_path = f'{payload_directory}train_v0_{idx}_payload_flux.json'
mask_preprocessing_engine.plot(payload_address=mask_payload_file_path)

#: Normalize to [0, 1] before passing the image to the mask engine, because that engine is a CNN.
#: Vectorized .min()/.max() replace the builtin min/max, which iterate element by element in Python.
image_raw = image_preprocessing_engine(payload_address=image_payload_file_path)
image_shifted = image_raw - image_raw.min()
image_range = image_shifted.max()
#: A constant image has zero range; dividing by it would fill the image with NaNs.
image = image_shifted / image_range if image_range > 0 else image_shifted
mask_maker_engine.plot(image=image)

signal_maker_engine.plot(dispersed_freq_time=mask_maker_engine(image=image))
print(f'Label is {label_reader_engine(filename=image_payload_file_path)}')

#### Train CNNs

- Set up training datasets
    - Image-Mask pair dataset
    - InMask-Mask pair dataset
    - Signal-Label pair dataset

In [None]:
#: Training split. In each tag, '*' is the placeholder for the index of an image in the set.
image_tag = 'train_v0_*_payload_detected.json'
image_directory = './syn_data/payloads/'

mask_tag = 'train_v0_*_payload_flux.json'
mask_directory = './syn_data/payloads/'

#: Arguments shared by the image->mask and in-mask->mask pair datasets.
train_pair_kwargs = dict(image_tag=image_tag,
                         mask_tag=mask_tag,
                         image_directory=image_directory,
                         mask_directory=mask_directory,
                         image_engine=image_preprocessing_engine,
                         mask_engine=mask_preprocessing_engine)

image_mask_train_dataset = ImageToMaskDataset(**train_pair_kwargs)
#: Same files, plus `mask_maker_engine` (presumably used to produce the in-mask input from each image).
inmask_mask_train_dataset = InMaskToMaskDataset(mask_maker_engine=mask_maker_engine, **train_pair_kwargs)

#: Signal-label pairs are built from the mask files only.
signal_label_train_dataset = SignalToLabelDataset(mask_tag=mask_tag,
                                                  mask_directory=mask_directory,
                                                  mask_engine=mask_preprocessing_engine)

- Plot pair from datasets

In [None]:
idx = 0
#: Plot the same sample index from every training dataset as a visual sanity check.
for dataset_to_plot in (image_mask_train_dataset,
                        inmask_mask_train_dataset,
                        signal_label_train_dataset):
    dataset_to_plot.plot(index=idx)

- Instantiate the trainers

In [None]:
#: Image -> mask: UNet, 10 epochs, weighted BCE with pos_weight=3 and neg_weight=1.
store_trained_model_image2mask_at = './syn_data/model/trained_UNet_test_v0.pt'
image2mask_model = UNet()
image2mask_loss = WeightedBCELoss(pos_weight=3, neg_weight=1)
image2mask_network_trainer = TrainImageToMaskNetworkModel(model=image2mask_model,
                                                          num_epochs=10,
                                                          store_trained_model_at=store_trained_model_image2mask_at,
                                                          loss_criterion=image2mask_loss)

#: In-mask -> mask: FilterCNN, 3 epochs, unweighted (1:1) BCE.
store_trained_model_inmask2mask_at = './syn_data/model/trained_FilterCNN_test_v0.pt'
inmask2mask_model = FilterCNN()
inmask2mask_loss = WeightedBCELoss(pos_weight=1, neg_weight=1)
inmask2mask_network_trainer = TrainImageToMaskNetworkModel(model=inmask2mask_model,
                                                           num_epochs=3,
                                                           store_trained_model_at=store_trained_model_inmask2mask_at,
                                                           loss_criterion=inmask2mask_loss)

#: Signal -> label: 1D CNN, 20 epochs, unweighted (1:1) BCE.
store_trained_model_signal2label_at: str = './syn_data/model/trained_CNN1D_test_v0.pt'
signal2label_model = CNN1D()
signal2label_loss = WeightedBCELoss(pos_weight=1, neg_weight=1)
signal2label_network_trainer = TrainSignalToLabelModel(model=signal2label_model,
                                                       num_epochs=20,
                                                       loss_criterion=signal2label_loss,
                                                       store_trained_model_at=store_trained_model_signal2label_at)





- Start training

In [None]:
#: Train the image->mask UNet only when no checkpoint exists yet at its store path, so that
#: "Run All" on a fresh checkout still produces the file later cells pass to the pipelines.
#: Delete that file to force retraining.
#: NOTE(review): assumes the trainer saves its weights to `store_trained_model_at` — confirm.
import os
if not os.path.exists(store_trained_model_image2mask_at):
    image2mask_network_trainer(image_mask_pairset=image_mask_train_dataset)

In [None]:
#: Train the in-mask->mask FilterCNN only when its checkpoint does not exist yet; delete the file to retrain.
#: NOTE(review): `inmask_mask_train_dataset` uses `mask_maker_engine`, which was built from the UNet
#: checkpoint path in the engines cell. If the UNet was (re)trained after that cell ran, re-run the
#: engines and training-dataset cells first so the inputs come from the trained UNet.
import os
if not os.path.exists(store_trained_model_inmask2mask_at):
    inmask2mask_network_trainer(image_mask_pairset=inmask_mask_train_dataset)

In [None]:
#: Train the signal->label CNN1D only when its checkpoint does not exist yet; delete the file to retrain.
import os
if not os.path.exists(store_trained_model_signal2label_at):
    signal2label_network_trainer(signal_label_pairset=signal_label_train_dataset)

- Set up test datasets
    - Image-Mask pair dataset
    - InMask-Mask pair dataset
    - Signal-Label pair dataset

In [None]:
#: Test split. In each tag, '*' is the placeholder for the index of an image in the set.
#: These four names are also read later by the pipeline data sets (im_set, m_set, label_set).
image_tag = 'test_v0_*_payload_detected.json'
image_directory = './syn_data/payloads/'

mask_tag = 'test_v0_*_payload_flux.json'
mask_directory = './syn_data/payloads/'

#: Arguments shared by the image->mask and in-mask->mask pair datasets.
test_pair_kwargs = dict(image_tag=image_tag,
                        mask_tag=mask_tag,
                        image_directory=image_directory,
                        mask_directory=mask_directory,
                        image_engine=image_preprocessing_engine,
                        mask_engine=mask_preprocessing_engine)

image_mask_test_dataset = ImageToMaskDataset(**test_pair_kwargs)
#: Same files, plus `mask_maker_engine` (presumably used to produce the in-mask input from each image).
inmask_mask_test_dataset = InMaskToMaskDataset(mask_maker_engine=mask_maker_engine, **test_pair_kwargs)

#: Signal-label pairs are built from the mask files only.
signal_label_test_dataset = SignalToLabelDataset(mask_tag=mask_tag,
                                                 mask_directory=mask_directory,
                                                 mask_engine=mask_preprocessing_engine)

- Plot pair from datasets

In [None]:
idx = 50
#: Plot the same sample index from every test dataset as a visual sanity check.
for dataset_to_plot in (image_mask_test_dataset,
                        inmask_mask_test_dataset,
                        signal_label_test_dataset):
    dataset_to_plot.plot(index=idx)

- Run the trained models on a sample from the test dataset

In [None]:
idx = 0
#: Fetch the test pair once: indexing the dataset twice repeated the payload loading and preprocessing,
#: and would give a mismatched image/mask if item retrieval were ever non-deterministic.
test_sample = image_mask_test_dataset[idx]
image, mask = test_sample[0], test_sample[1]
pred = image2mask_network_trainer.test_model(image=image, plot_pred=True)
#: NOTE(review): the filter CNN is trained on InMaskToMaskDataset pairs (inputs presumably produced by
#: mask_maker_engine), but here it receives the ground-truth mask; consider
#: `inmask_mask_test_dataset[idx][0]` as input to test it on its training distribution.
pred_filtered = inmask2mask_network_trainer.test_model(image=mask, plot_pred=True)
label = signal2label_network_trainer.test_model(mask=mask.squeeze().detach().numpy(), plot_pred=True)


#### Imports for pipelines

In [None]:
from src.pulsar_analysis.information_packet_formats import Payload
from src.pulsar_analysis.pipeline_methods import ImageDataSet, ImageReader, PipelineImageToDelGraphtoIsPulsar,PipelineImageToFilterDelGraphtoIsPulsar,LabelDataSet,LabelReader

#### Analysis Pipelines

- Instantiate Pipelines for detecting only pulsars
    - Pipeline: Segment -> DelayGraph -> Label
    - Pipeline: Segment -> Filtered Segment -> DelayGraph -> Label

In [None]:
#: Pipeline 1: segment (UNet) -> delay graph -> pulsar label (CNN1D).
ppl1 = PipelineImageToDelGraphtoIsPulsar(
    image_to_mask_network=UNet(),
    trained_image_to_mask_network_path=store_trained_model_image2mask_at,
    signal_to_label_network=CNN1D(),
    trained_signal_to_label_network=store_trained_model_signal2label_at,
)

#: Pipeline 1f: like pipeline 1, with the FilterCNN applied to the UNet segment before the delay graph.
ppl1f = PipelineImageToFilterDelGraphtoIsPulsar(
    image_to_mask_network=UNet(),
    trained_image_to_mask_network_path=store_trained_model_image2mask_at,
    mask_filter_network=FilterCNN(),
    trained_mask_filter_network_path=store_trained_model_inmask2mask_at,
    signal_to_label_network=CNN1D(),
    trained_signal_to_label_network=store_trained_model_signal2label_at,
)

- Set the datasets for testing

In [None]:
#: These read image_tag / mask_tag / *_directory as last assigned in the test-dataset cell (the test split).
image_reader = ImageReader(file_type=Payload([]), do_average=True)
im_set = ImageDataSet(image_tag=image_tag,
                      image_directory=image_directory,
                      image_reader_engine=image_reader)

mask_reader = ImageReader(file_type=Payload([]), do_average=True, do_binarize=True)
m_set = ImageDataSet(image_tag=mask_tag,
                     image_directory=mask_directory,
                     image_reader_engine=mask_reader)

label_reader = LabelReader(file_type=Payload([]))
label_set = LabelDataSet(image_tag=image_tag,
                         image_directory=image_directory,
                         label_reader_engine=label_reader)

- Test on test dataset

In [None]:
#: Segment -> DelayGraph -> Label on the test split; randomize and ids_toshow are forwarded as-is.
ppl1.display_results_in_batch(
    image_data_set=im_set,
    mask_data_set=m_set,
    label_data_set=label_set,
    randomize=True,
    ids_toshow=[71, 96],
)

In [None]:
#: Segment -> Filtered Segment -> DelayGraph -> Label on the test split; randomize and ids_toshow are forwarded as-is.
ppl1f.display_results_in_batch(
    image_data_set=im_set,
    mask_data_set=m_set,
    label_data_set=label_set,
    randomize=True,
    ids_toshow=[71, 96],
)

- Instantiate Pipelines for detecting different categories
    - Pipeline: Segment -> CC -> Categories
    - Pipeline: Segment -> Filtered Segment -> CC -> Categories

In [None]:
#: Pipeline 2: segment (UNet) -> connected components -> categories (min_cc_size_threshold=5).
ppl2 = PipelineImageToCCtoLabels(
    image_to_mask_network=UNet(),
    trained_image_to_mask_network_path=store_trained_model_image2mask_at,
    min_cc_size_threshold=5,
)

#: Pipeline 2f: like pipeline 2, with the FilterCNN applied to the UNet segment first.
ppl2f = PipelineImageToFilterToCCtoLabels(
    image_to_mask_network=UNet(),
    trained_image_to_mask_network_path=store_trained_model_image2mask_at,
    mask_filter_network=FilterCNN(),
    trained_mask_filter_network_path=store_trained_model_inmask2mask_at,
    min_cc_size_threshold=5,
)

- Test on test dataset

In [None]:
#: Segment -> CC -> Categories on the test split; randomize and ids_toshow are forwarded as-is.
ppl2.display_results_in_batch(
    image_data_set=im_set,
    mask_data_set=m_set,
    label_data_set=label_set,
    randomize=True,
    ids_toshow=[71, 96],
)

In [None]:
#: Segment -> Filtered Segment -> CC -> Categories on the test split; randomize and ids_toshow are forwarded as-is.
ppl2f.display_results_in_batch(
    image_data_set=im_set,
    mask_data_set=m_set,
    label_data_set=label_set,
    randomize=True,
    ids_toshow=[71, 96],
)

- Measure accuracy on the test dataset

In [None]:
#: Efficiency validation of the filtered pulsar pipeline is opt-in and skipped by default
#: (presumably because it runs over the whole test set). Set the flag to True to run it.
run_efficiency_validation = False
if run_efficiency_validation:
    ppl1f.validate_efficiency(image_data_set=im_set, label_data_set=label_set)

#### Test pipelines on real-world data

- Load data

In [None]:
import numpy as np

#: Replace the placeholder paths below before running this cell.
#: Image file: numpy array (opened read-only via memmap) of real pulsar dispersion graphs.
#: If your data is stored differently, design your own dataloader class instead.
image_directory_npy = 'path_to_real_image_data'
#: Label file: numpy array with the corresponding label of each image.
label_directory_npy = 'path_to_real_label_data'
data = np.load(file=image_directory_npy, mmap_mode='r')
data_label = np.load(file=label_directory_npy, mmap_mode='r')

#: Take exactly `size_of_set` consecutive samples starting at index `offset`
#: (a start of offset+1 would skip sample `offset` and return only size_of_set-1 samples).
offset = 5000
size_of_set = 500
subset_window = slice(offset, offset + size_of_set)
data_subset = data[subset_window, :, :]
data_label_subset = data_label[subset_window]
#: Images and labels must stay aligned one-to-one.
assert len(data_subset) == len(data_label_subset), 'image and label subsets differ in length'

- Deploy Pipelines

In [None]:
#: Pulsar pipeline on the full real-data subset, with detailed plots.
ppl1.test_on_real_data_from_npy_files(
    image_data_set=data_subset,
    image_label_set=data_label_subset,
    plot_details=True,
    plot_randomly=True,
    batch_size=2,
)

In [None]:
#: Filtered pulsar pipeline on a two-sample window (indices 8 and 9) of the real-data subset.
real_window = slice(8, 10)
ppl1f.test_on_real_data_from_npy_files(
    image_data_set=data_subset[real_window, :, :],
    image_label_set=data_label_subset[real_window],
    plot_details=True,
    plot_randomly=True,
    batch_size=2,
)

In [None]:
#: Connected-component category pipeline on the full real-data subset.
ppl2.test_on_real_data_from_npy_files(
    image_data_set=data_subset,
    image_label_set=data_label_subset,
    plot_randomly=True,
    batch_size=2,
)

In [None]:
#: Filtered connected-component pipeline on a two-sample window (indices 8 and 9) of the real-data subset.
real_window = slice(8, 10)
ppl2f.test_on_real_data_from_npy_files(
    image_data_set=data_subset[real_window, :, :],
    image_label_set=data_label_subset[real_window],
    plot_randomly=True,
    batch_size=2,
)