In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
import torch.nn as nn
import torch
import pandas as pd
import pytorch_lightning as pl
from bayesian_nnet.net_utils import initialize_pl_trainer, get_model_checkpoint_dir
from bayesian_nnet.data import IQDataModel
from bayesian_nnet.cnnlstm import AutoModClassifier
from bayesian_nnet.model_eval import init_snrid_accuracy, evaluate_model_predictions

In [18]:
checkpoint_file = "epoch=23-val_loss=0.00-other_metric=0.00.ckpt"

checkpoint_pth =\
    get_model_checkpoint_dir("frequentist_net").joinpath(checkpoint_file)

model = AutoModClassifier.load_from_checkpoint(checkpoint_pth)
trainer = pl.Trainer()
iqdata = IQDataModel(batch_size=256)

predictions = trainer.predict(model,
                              datamodule=iqdata)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
2025-04-04 08:35:39.835953: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-04 08:35:40.161507: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1743770140.399494    2518 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1743770140.426165    2518 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-04 08:35:40.801514:

Predicting: |                                                                                        | 0/? [00…

In [31]:
from bayesian_nnet.model_eval import ModelPredictionFormatter
from sklearn.metrics import classification_report, confusion_matrix
pred_fmt = ModelPredictionFormatter()
mode_conf_preds = pd.concat([pred_fmt(elem) for elem in predictions])

In [45]:
predictions

[(tensor([[1.4339e-08, 3.4409e-09, 1.2186e-03,  ..., 2.5477e-08, 2.5476e-08,
           1.7060e-09],
          [1.6667e-15, 5.4265e-13, 4.3293e-09,  ..., 8.1510e-10, 4.7055e-10,
           2.0651e-15],
          [1.5796e-10, 9.5767e-10, 3.6733e-10,  ..., 3.0723e-10, 6.8684e-11,
           7.7904e-16],
          ...,
          [2.2483e-05, 2.8529e-03, 9.9408e-01,  ..., 8.7199e-06, 6.6354e-07,
           4.2769e-09],
          [2.9650e-06, 9.9194e-01, 4.1525e-03,  ..., 2.2178e-06, 6.9022e-08,
           2.9219e-10],
          [6.2295e-12, 9.4953e-11, 4.9879e-06,  ..., 5.2547e-09, 2.1382e-07,
           1.8174e-11]]),
  tensor([14,  9,  4, 14, 16,  6,  6,  1, 16,  5,  2, 12,  6, 12,  1, 11,  2, 11,
           8,  4,  6,  0, 14,  8, 16, 13, 13,  2,  8, 13, 11,  7,  5, 10, 15, 16,
          15, 14,  1, 11,  0,  2,  2, 15,  4,  5,  7,  8, 13,  7,  5,  5,  6,  7,
           6,  6, 15, 15,  2, 16, 11,  5,  5, 15,  6,  0,  0, 12,  8,  2,  8, 10,
          13,  5, 13,  5,  2,  7,  7, 11, 14,  7,

In [32]:
snrid_clf_report = dict()
snrid_conf_mat = dict()
modeids = list(mode_conf_preds.columns[:18])

for snrid in mode_conf_preds["snrid"].value_counts().keys():

    select_row = mode_conf_preds["snrid"] == snrid

    truemodeid =\
        mode_conf_preds.loc[select_row, "truemodeid"].values

    predictedmodeid =\
        mode_conf_preds.loc[select_row, "predictedmodeid"].values

    snrid_clf_report[snrid] =\
        classification_report(truemodeid,
                              predictedmodeid,
                              zero_division=0,
                              output_dict=True)

    cur_conf_mat =\
        confusion_matrix(truemodeid,
                         predictedmodeid,
                         normalize="true")

    snrid_conf_mat[snrid] =\
        pd.DataFrame(cur_conf_mat,
                     index=modeids,
                     columns=modeids)

ValueError: Shape of passed values is (16, 16), indices imply (18, 18)

In [29]:

snrid_clf_report

{'am': {'precision': 0.9343434343434344,
  'recall': 0.8851674641148325,
  'f1-score': 0.9090909090909091,
  'support': 209.0},
 'dominoex11': {'precision': 0.9420289855072463,
  'recall': 0.9605911330049262,
  'f1-score': 0.9512195121951219,
  'support': 203.0},
 'fax': {'precision': 0.9885057471264368,
  'recall': 0.9555555555555556,
  'f1-score': 0.9717514124293786,
  'support': 180.0},
 'morse': {'precision': 1.0,
  'recall': 0.9940476190476191,
  'f1-score': 0.9970149253731343,
  'support': 168.0},
 'mt63_1000': {'precision': 0.9829545454545454,
  'recall': 0.9942528735632183,
  'f1-score': 0.9885714285714285,
  'support': 174.0},
 'navtex': {'precision': 1.0,
  'recall': 0.9896907216494846,
  'f1-score': 0.9948186528497409,
  'support': 194.0},
 'olivia16_1000': {'precision': 0.9856459330143541,
  'recall': 1.0,
  'f1-score': 0.9927710843373494,
  'support': 206.0},
 'olivia16_500': {'precision': 0.7594339622641509,
  'recall': 0.8214285714285714,
  'f1-score': 0.7892156862745098