In [1]:
import torch
import matplotlib.pyplot as plt
from model import PIBFilTransformer
from dataset import PIBDataset
from torch.utils.data import DataLoader
import lightning as L

In [2]:
# --- Evaluation configuration -------------------------------------------------
# Active inputs for this run.
data_filename = 'data/test/streaming/test.json'
checkpoint_path = 'results/transformer/lightning_logs/version_16323998/checkpoints/epoch=2-step=72450.ckpt'

# Alternatives that were used previously (swap in by hand if needed):
#   data_filename   = 'data/person-in-bed-streaming-detection/train.json'
#   checkpoint_path = 'results/transformer-cc/lightning_logs/version_15967234/checkpoints/epoch=9-step=241990.ckpt'
#   checkpoint_path = 'results/transformer/lightning_logs/version_9156/checkpoints/epoch=19-step=483000.ckpt'

# Presumably [train, validate, test] fractions, i.e. the whole file lands in the
# 'validate' split that is loaded below -- TODO confirm against PIBDataset.
train_validate_test_split = [0, 1, 0]

# --- Model --------------------------------------------------------------------
# Restore the trained weights and switch to inference mode.
pib_model = PIBFilTransformer.load_from_checkpoint(checkpoint_path)
pib_model.eval()

# --- Data ---------------------------------------------------------------------
dataset_kwargs = dict(
    mode='validate',
    data_path=data_filename,
    train_validate_test_split=train_validate_test_split,
)
dataset_val = PIBDataset(**dataset_kwargs)
# shuffle=False keeps the predictions in dataset order.
loader_val = DataLoader(dataset=dataset_val, batch_size=256, shuffle=False)

In [None]:
# Sanity check: how many examples the evaluation dataset exposes.
n_eval_examples = len(dataset_val)
print(n_eval_examples)

In [None]:
# Run inference over the whole loader. No checkpoints or logs are written; the
# per-batch prediction tuples are returned in memory.
trainer = L.Trainer(default_root_dir='results/transformer', enable_checkpointing=False, logger=False)
returned_values = trainer.predict(pib_model, loader_val, ckpt_path=checkpoint_path, return_predictions=True)


def _gather_field(field_idx):
    """Concatenate positional field `field_idx` of every per-batch result along dim 0."""
    return torch.cat([batch_result[field_idx] for batch_result in returned_values], dim=0)


# Positional layout of each per-batch tuple, as consumed here:
#   0 outputs, 1 labels, 2 model_llrs, 3 ts_list, 4 prev_sums, 5 subjects
outputs, labels, model_llrs, ts_list, prev_sums, subjects = (
    _gather_field(field_idx) for field_idx in range(6)
)

# Field 6 (chunk ids) is not read; a running row index stands in for it.
chunk_ids = torch.arange(len(ts_list))

In [5]:
# Dump one CSV row per valid sample: running id, subject, timestamp, predicted label.
output_filename = 'streaming_test_output.csv'
with open(output_filename, 'w') as f:
    print('id,subject,timestamp,label', file=f)
    # Named `row_id` (not `id`) so the builtin id() is not shadowed for the
    # rest of the notebook session.
    row_id = 0
    for i in range(len(chunk_ids)):
        subject = int(subjects[i])
        # Convert each row to Python scalars once; indexing a tensor element by
        # element and calling .item() yields the same values but is far slower.
        row_timestamps = ts_list[i].tolist()
        row_predictions = outputs[i].tolist()
        for timestamp, prediction in zip(row_timestamps, row_predictions):
            # A negative timestamp ends the row (presumably padding -- TODO
            # confirm against the dataset's collate logic).
            if timestamp < 0:
                break
            print(f'{row_id},{subject},{timestamp},{prediction}', file=f)
            row_id += 1

In [None]:
# Detection latency per label transition.
#
# For every index where the label changes, scan forward for the first position
# from which the prediction equals the new label for a whole look-ahead window,
# and record the distance in seconds.
SAMPLES_PER_SECOND = 250                      # assumed sampling rate (plots use /250. for seconds) -- TODO confirm
STABLE_WINDOW = 15 * SAMPLES_PER_SECOND + 1   # look-ahead length in samples

new_outputs = outputs
avg_latency_list = []   # mean latency of each row that has at least one transition
latency_lists = []      # raw latencies of each such row
for i in range(len(new_outputs)):
    row_pred = new_outputs[i]
    row_true = labels[i]
    row_len = len(row_pred)

    # Indices (>= 1) at which the label differs from the previous sample.
    candidate_idx = torch.arange(1, row_len, device=new_outputs.device, dtype=torch.long)
    changed = (row_true[1:] - row_true[:-1]) != 0
    transitions = candidate_idx[changed]

    latency_list = []
    for j in transitions:
        # Negative timestamp: stop processing this row.
        if ts_list[i][j] < 0:
            break
        target = row_true[j]
        # int() gives an independent Python integer, so advancing k never
        # touches the tensor element j.
        k = int(j)
        # Advance until the window starting at k agrees with the new label
        # everywhere (the window is truncated at the end of the row), or the
        # row is exhausted.
        while k < row_len and not bool((row_pred[k:k + STABLE_WINDOW] == target).all()):
            k += 1
        latency_list.append((k - j) / float(SAMPLES_PER_SECOND))

    if latency_list:
        avg_latency_list.append(torch.mean(torch.tensor(latency_list)))
        latency_lists.append(latency_list)

In [None]:
# Latency score: exp(-2 * ln(1 - latency / 60)^2), averaged over every recorded
# transition latency.
#
# Latencies are capped at 60 s before scoring. Without the cap, a latency above
# 60 s makes the log argument negative, which yields NaN and turns the printed
# mean into NaN. At exactly 60 s the expression evaluates to 0, so the cap
# assigns the minimum score to any later detection.
# NOTE(review): confirm that scoring >60 s as 0 matches the intended metric.
MAX_LATENCY_S = 60.
add_to_list = []
for latency_list in latency_lists:
    for latency in latency_list:
        capped_latency = torch.clamp(torch.as_tensor(latency), max=MAX_LATENCY_S)
        add_to_list.append(torch.exp(-2.*torch.pow(torch.log(1.-capped_latency/60.), 2)))
print(torch.mean(torch.tensor(add_to_list)))

In [None]:
# Visual inspection of one row: LLR traces on the left, labels vs. predictions
# on the right.
#
# Fixed row index. To inspect a random row instead, use
#   i = torch.randint(0, len(outputs), (1,)).item()
# (the original cell drew a random index and then immediately overwrote it).
i = 9

# Samples with a non-negative timestamp are the ones that get plotted.
valid = ts_list[i] >= 0.
# x axis: sample index divided by 250, labelled as seconds below.
time_s = torch.arange(len(outputs[i][valid]))/250.

fig, axs = plt.subplots(1, 2, squeeze=False, figsize=(15, 5))

# NOTE(review): the original plotted `prev_sum_new`, which no cell defines, so
# the cell raised NameError on a fresh kernel. `prev_sums` from the prediction
# cell is used instead -- confirm this is the intended series.
axs[0, 0].plot(time_s, model_llrs[i][valid], label='LLRs')
axs[0, 0].plot(time_s, prev_sums[i][valid], label='Avg LLRs')
axs[0, 0].legend(fontsize='x-large')
axs[0, 0].grid()
axs[0, 0].set_xlabel('Time (s)', fontsize='x-large')
axs[0, 0].set_ylabel('Prediction', fontsize='x-large')
axs[0, 0].set_title(f'Index i = {i}', fontsize='x-large')

axs[0, 1].plot(time_s, labels[i][valid], label='Original')
axs[0, 1].plot(time_s, new_outputs[i][valid], label='Output')
axs[0, 1].legend(fontsize='x-large')
axs[0, 1].grid()
# axs[0, 1].set_xlim([700., 760.])  # uncomment to zoom into a time range
axs[0, 1].set_xlabel('Time (s)', fontsize='x-large')
axs[0, 1].set_ylabel('Prediction', fontsize='x-large')
axs[0, 1].set_title(f'Index i = {i}', fontsize='x-large')

plt.show()

In [None]:
# Shape of the latency score exp(-2 * ln(1 - delta / 60)^2) for delta in
# [0, 59.9) s, zoomed to the first 10 s where the score stays above 0.9.
delta = torch.arange(0., 59.9, 0.01)
log_term = torch.log(1 - delta/60.)
metric = torch.exp(-2.*log_term**2)

metric_fig, metric_ax = plt.subplots()
metric_ax.plot(delta, metric)
metric_ax.grid()
metric_ax.set_xlim([0., 10.])
metric_ax.set_ylim([0.9, 1.])
plt.show()