****************************************************************
# Neural Networks for Regression
****************************************************************

In [1]:
import numpy as np
import cait as ai
from pytorch_lightning import Trainer
from torchvision import transforms
import h5py
from cait.datasets import RemoveOffset, Normalize, DownSample, ToTensor, CryoDataModule
from cait.models import LSTMModule, nn_predict
from pytorch_lightning.callbacks import ModelCheckpoint
import matplotlib.pyplot as plt
%config InlineBackend.figure_formats = ['svg']  # we need this for a suitable resolution of the plots

In [2]:
# some parameters
# nmbr_gpus = ... uncommment and put in trainer to use GPUs
path_h5 = 'test_data/efficiency_001.h5'
type = 'events'
keys = ['event', 'true_ph']
channel_indices = [[0], [0]]
feature_indices = [None, None]
feature_keys = ['event_ch0']
label_keys = ['true_ph_ch0']
norm_vals = {'event_ch0': [0, 1]}
down_keys = ['event_ch0']
down = 8
input_size = 8
nmbr_out = 1
device_name='cpu'
max_epochs = 30
save_naming = 'lstm-reg'

## Dataset und Model

In [3]:
# create the transforms
transforms = transforms.Compose([RemoveOffset(keys=feature_keys),
                                 Normalize(norm_vals=norm_vals),
                                 DownSample(keys=down_keys, down=down),
                                 ToTensor()])

In [4]:
# create data module and init the setup
dm = CryoDataModule(hdf5_path=path_h5,
                    type=type,
                    keys=keys,
                    channel_indices=channel_indices,
                    feature_indices=feature_indices,
                    transform=transforms)

In [5]:
dm.prepare_data(val_size=0.2,
                test_size=0.2,
                batch_size=8,
                dataset_size=None,
                nmbr_workers=0,  # set to number of CPUS on the machine
                only_idx=None,
                shuffle_dataset=True,
                random_seed=42,
                feature_keys=feature_keys,
                label_keys=label_keys,
                keys_one_hot=[])

In [6]:
dm.setup()

In [7]:
# create lstm clf
lstm = LSTMModule(input_size=input_size,
                  hidden_size=input_size * 10,
                  num_layers=2,
                  seq_steps=int(dm.dims[1] / input_size),  # downsampling is already considered in dm
                  device_name=device_name,
                  nmbr_out=nmbr_out,  # this is the number of labels
                  lr=1e-5,
                  label_keys=label_keys,
                  feature_keys=feature_keys,
                  is_classifier=False,
                  down=down,
                  down_keys=feature_keys,
                  norm_vals=norm_vals,
                  offset_keys=feature_keys)

## Tensorboard

:::{note}
**Tensorboard on Server without X-Forwarding**
If you work on a remote server that has X-forwarding deactivated, i.e. you don't have to option to show graphical elements, you can start the ssh connection with the additional -L flag:

    ssh -L 16006:127.0.0.1:6006 <SERVER_SSH_ADRESS>
    
Then your local machine listens to the standard port of tensorboard on the remote server and you can open the tensorboard interface in a browser on your local machine by typing http://127.0.0.1:16006/ in the address line.
:::

> %load_ext tensorboard

> %tensorboard --logdir=lightning_logs

## Training

In [10]:
# create callback to save the best model
checkpoint_callback = ModelCheckpoint(dirpath='callbacks',
                                      monitor='val_loss',
                                      filename=save_naming + '-{epoch:02d}-{val_loss:.2f}')

In [11]:
# create instance of Trainer
trainer = Trainer(max_epochs=max_epochs,
                  callbacks=[checkpoint_callback])
# keyword gpus=nmbr_gpus for GPU Usage
# keyword max_epochs for number of maximal epochs

GPU available: False, used: False
TPU available: False, using: 0 TPU cores


In [12]:
# all training happens here
trainer.fit(model=lstm,
            datamodule=dm)


  | Name | Type   | Params
--------------------------------
0 | lstm | LSTM   | 80 K  
1 | fc1  | Linear | 20 K  


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…



HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…



HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




1

## Evaluation

In [13]:
# load best model
lstm.load_from_checkpoint(checkpoint_callback.best_model_path)
#lstm.load_from_checkpoint('callbacks/lstm-clf-epoch=28-val_loss=0.05.ckpt')

LSTMModule(
  (lstm): LSTM(8, 80, num_layers=2, batch_first=True)
  (fc1): Linear(in_features=20480, out_features=1, bias=True)
)

In [14]:
# run test set
result = trainer.test()
print(result)



HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': tensor(0.0752),
 'train_loss': tensor(0.0531),
 'val_loss': tensor(0.0760)}
--------------------------------------------------------------------------------

[{'train_loss': 0.053097039461135864, 'val_loss': 0.0759606659412384, 'test_loss': 0.07515044510364532}]


In [15]:
# predictions with the model are made that way
f = h5py.File(dm.hdf5_path, 'r')
test_idx = dm.test_sampler.indices
test_idx.sort()
x = {feature_keys[0]: f[type][keys[0]][channel_indices[0][0], test_idx]}  # array of shape: (nmbr_events, nmbr_features)
y = np.array(f[type][keys[1]][channel_indices[1][0], test_idx])
prediction = lstm.predict(x).numpy()

# predictions can be saved with instance of EvaluationTools
print('RMS OF PREDICTION: ', np.sqrt(np.mean((prediction - y)**2)))
print('Best model: ', checkpoint_callback.best_model_path)
print('Predictions: ', prediction)

RMS OF PREDICTION:  0.2744119057615132
Best model:  /Users/felix/PycharmProjects/cait/docs/source/tutorials/callbacks/lstm-reg-epoch=11-val_loss=0.07.ckpt
Predictions:  [[0.45109764]
 [0.45154372]
 [0.45151246]
 [0.452013  ]
 [0.4522721 ]
 [0.45208302]
 [0.45176235]
 [0.4511507 ]
 [0.45101056]
 [0.45176452]
 [0.45148477]
 [0.4516974 ]
 [0.45161384]
 [0.4516814 ]
 [0.45174435]
 [0.45162496]
 [0.45147425]
 [0.4521565 ]
 [0.45063412]
 [0.45144284]]
