In [1]:
import os
import numpy as np 
import zstandard
from tqdm import tqdm
import torch
from scipy import signal
import matplotlib
import matplotlib.pyplot as plt
from timeit import default_timer as timer
import torchvision
from torchvision import datasets, models, transforms
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from gamutrf_dataset import * 

from gamutrf.sample_reader import get_reader
from gamutrf.utils import parse_filename 

In [2]:
# Map each class label to the directories holding its recordings.
# Drone captures come from two field days; the Wi-Fi captures come from a single
# 2022-07-21 PDX session, split (judging by directory names) into 2.4 GHz and 5 GHz bands.
label_dirs = dict(
    drone=[
        'data/gamutrf-birdseye-field-days/leesburg_field_day_2022_06_15/worker1/',
        'data/gamutrf-birdseye-field-days/pdx_field_day_2022_05_26/worker1/gamutrf/',
    ],
    wifi_2_4=['data/gamutrf-pdx/07_21_2022/wifi_2_4/'],
    wifi_5=['data/gamutrf-pdx/07_21_2022/wifi_5/'],
)

In [3]:
# Preprocessing applied to every sample: convert to a tensor, then resize to 256x256.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((256, 256)),
    ]
)

# Build the full labelled dataset from `label_dirs`.
# NOTE(review): sample_secs presumably sets the per-sample window length in seconds
# and nfft the FFT size used to build each sample — confirm against GamutRFDataset.
dataset = GamutRFDataset(
    label_dirs,
    sample_secs=0.02,
    nfft=512,
    transform=transform,
)

label='drone', 131 files
label='wifi_2_4', 263 files
label='wifi_5', 366 files


 60%|████████████████████████████████████████████████████████████████████▋                                             | 79/131 [00:00<00:00, 777.69it/s]

loading data/gamutrf-birdseye-field-days/leesburg_field_day_2022_06_15/worker1/gamutrf_recording_ettus_directional-split_gain70_1655312998_5735000000Hz_20971520sps.s16.zst_0.02.npy; 0/131 time = 0.003083745948970318
loading data/gamutrf-birdseye-field-days/leesburg_field_day_2022_06_15/worker1/gamutrf_recording_ettus_directional-split_gain70_1655311950_5735000000Hz_20971520sps.s16.zst_0.02.npy; 1/131 time = 0.001958506996743381
loading data/gamutrf-birdseye-field-days/leesburg_field_day_2022_06_15/worker1/gamutrf_recording_ettus_directional-split_gain70_1655312216_5735000000Hz_20971520sps.s16.zst_0.02.npy; 2/131 time = 0.0019061639904975891
loading data/gamutrf-birdseye-field-days/leesburg_field_day_2022_06_15/worker1/gamutrf_recording_ettus_directional-split_gain70_1655312014_5735000000Hz_20971520sps.s16.zst_0.02.npy; 3/131 time = 0.0016190570313483477
loading data/gamutrf-birdseye-field-days/leesburg_field_day_2022_06_15/worker1/gamutrf_recording_ettus_directional-split_gain70_165531

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 131/131 [00:00<00:00, 448.05it/s]


loading data/gamutrf-birdseye-field-days/pdx_field_day_2022_05_26/worker1/gamutrf/gamutrf_recording_ettus_directional-split_gain45_1653589152_5735000000Hz_20971520sps.s16.zst_0.02.npy; 90/131 time = 0.004266351985279471
loading data/gamutrf-birdseye-field-days/pdx_field_day_2022_05_26/worker1/gamutrf/gamutrf_recording_ettus_directional-split_gain45_1653589098_5735000000Hz_20971520sps.s16.zst_0.02.npy; 91/131 time = 0.004175602982286364
loading data/gamutrf-birdseye-field-days/pdx_field_day_2022_05_26/worker1/gamutrf/gamutrf_recording_ettus_directional-split_gain45_1653589857_5735000000Hz_20971520sps.s16.zst_0.02.npy; 92/131 time = 0.004105524974875152
loading data/gamutrf-birdseye-field-days/pdx_field_day_2022_05_26/worker1/gamutrf/gamutrf_recording_ettus_directional-split_gain45_1653590972_5735000000Hz_20971520sps.s16.zst_0.02.npy; 93/131 time = 0.003374330000951886
loading data/gamutrf-birdseye-field-days/pdx_field_day_2022_05_26/worker1/gamutrf/gamutrf_recording_ettus_directional-sp

  0%|                                                                                                                            | 0/263 [00:00<?, ?it/s]

loading data/gamutrf-pdx/07_21_2022/wifi_2_4/gamutrf_recording_ettus_directional_gain70_1658454977_2420000000Hz_20971520sps.s16.zst_0.02.npy; 0/263 time = 0.0012469350476749241
loading data/gamutrf-pdx/07_21_2022/wifi_2_4/gamutrf_recording_ettus_directional_gain70_1658446723_2420000000Hz_20971520sps.s16.zst_0.02.npy; 1/263 time = 0.0007910719723440707
loading data/gamutrf-pdx/07_21_2022/wifi_2_4/gamutrf_recording_ettus_directional_gain70_1658445024_2420000000Hz_20971520sps.s16.zst_0.02.npy; 2/263 time = 0.0006717059877701104
loading data/gamutrf-pdx/07_21_2022/wifi_2_4/gamutrf_recording_ettus_directional_gain70_1658450097_2440000000Hz_20971520sps.s16.zst_0.02.npy; 3/263 time = 0.0006774489884264767
loading data/gamutrf-pdx/07_21_2022/wifi_2_4/gamutrf_recording_ettus_directional_gain70_1658452126_2480000000Hz_20971520sps.s16.zst_0.02.npy; 4/263 time = 0.0006472999812103808
loading data/gamutrf-pdx/07_21_2022/wifi_2_4/gamutrf_recording_ettus_directional_gain70_1658454149_2420000000Hz_209

 14%|████████████████                                                                                                  | 37/263 [00:00<00:00, 268.75it/s]

loading data/gamutrf-pdx/07_21_2022/wifi_2_4/gamutrf_recording_ettus_directional_gain70_1658455218_2420000000Hz_20971520sps.s16.zst_0.02.npy; 37/263 time = 0.002067133958917111
loading data/gamutrf-pdx/07_21_2022/wifi_2_4/gamutrf_recording_ettus_directional_gain70_1658444004_2420000000Hz_20971520sps.s16.zst_0.02.npy; 38/263 time = 0.001238660013768822
loading data/gamutrf-pdx/07_21_2022/wifi_2_4/gamutrf_recording_ettus_directional_gain70_1658447239_2400000000Hz_20971520sps.s16.zst_0.02.npy; 39/263 time = 0.0010750270448625088
loading data/gamutrf-pdx/07_21_2022/wifi_2_4/gamutrf_recording_ettus_directional_gain70_1658444664_2420000000Hz_20971520sps.s16.zst_0.02.npy; 40/263 time = 0.0010672729695215821
loading data/gamutrf-pdx/07_21_2022/wifi_2_4/gamutrf_recording_ettus_directional_gain70_1658452165_2440000000Hz_20971520sps.s16.zst_0.02.npy; 41/263 time = 0.0010924480156973004
loading data/gamutrf-pdx/07_21_2022/wifi_2_4/gamutrf_recording_ettus_directional_gain70_1658451356_2440000000Hz_

 57%|████████████████████████████████████████████████████████████████▉                                                | 151/263 [00:00<00:00, 711.30it/s]

loading data/gamutrf-pdx/07_21_2022/wifi_2_4/gamutrf_recording_ettus_directional_gain70_1658446023_2400000000Hz_20971520sps.s16.zst_0.02.npy; 103/263 time = 0.0010068970150314271
loading data/gamutrf-pdx/07_21_2022/wifi_2_4/gamutrf_recording_ettus_directional_gain70_1658452714_2440000000Hz_20971520sps.s16.zst_0.02.npy; 104/263 time = 0.0016913569997996092
loading data/gamutrf-pdx/07_21_2022/wifi_2_4/gamutrf_recording_ettus_directional_gain70_1658454663_2440000000Hz_20971520sps.s16.zst_0.02.npy; 105/263 time = 0.0007372249965555966
loading data/gamutrf-pdx/07_21_2022/wifi_2_4/gamutrf_recording_ettus_directional_gain70_1658447656_2420000000Hz_20971520sps.s16.zst_0.02.npy; 106/263 time = 0.0007739479769952595
loading data/gamutrf-pdx/07_21_2022/wifi_2_4/gamutrf_recording_ettus_directional_gain70_1658450782_2480000000Hz_20971520sps.s16.zst_0.02.npy; 107/263 time = 0.0007595890201628208
loading data/gamutrf-pdx/07_21_2022/wifi_2_4/gamutrf_recording_ettus_directional_gain70_1658444454_242000

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 263/263 [00:00<00:00, 590.64it/s]


loading data/gamutrf-pdx/07_21_2022/wifi_2_4/gamutrf_recording_ettus_directional_gain70_1658448716_2480000000Hz_20971520sps.s16.zst_0.02.npy; 209/263 time = 0.0009431809885427356
loading data/gamutrf-pdx/07_21_2022/wifi_2_4/gamutrf_recording_ettus_directional_gain70_1658454583_2420000000Hz_20971520sps.s16.zst_0.02.npy; 210/263 time = 0.0010045860544778407
loading data/gamutrf-pdx/07_21_2022/wifi_2_4/gamutrf_recording_ettus_directional_gain70_1658451192_2440000000Hz_20971520sps.s16.zst_0.02.npy; 211/263 time = 0.0006876579718664289
loading data/gamutrf-pdx/07_21_2022/wifi_2_4/gamutrf_recording_ettus_directional_gain70_1658453263_2420000000Hz_20971520sps.s16.zst_0.02.npy; 212/263 time = 0.0006861180299893022
loading data/gamutrf-pdx/07_21_2022/wifi_2_4/gamutrf_recording_ettus_directional_gain70_1658443914_2420000000Hz_20971520sps.s16.zst_0.02.npy; 213/263 time = 0.0006553049897775054
loading data/gamutrf-pdx/07_21_2022/wifi_2_4/gamutrf_recording_ettus_directional_gain70_1658450226_244000

 38%|██████████████████████████████████████████▌                                                                     | 139/366 [00:00<00:00, 1381.09it/s]

loading data/gamutrf-pdx/07_21_2022/wifi_5/gamutrf_recording_ettus_directional_gain70_1658467088_5240000000Hz_20971520sps.s16.zst_0.02.npy; 0/366 time = 0.0013279860140755773
loading data/gamutrf-pdx/07_21_2022/wifi_5/gamutrf_recording_ettus_directional_gain70_1658464188_5780000000Hz_20971520sps.s16.zst_0.02.npy; 1/366 time = 0.0011093399953097105
loading data/gamutrf-pdx/07_21_2022/wifi_5/gamutrf_recording_ettus_directional_gain70_1658464104_5760000000Hz_20971520sps.s16.zst_0.02.npy; 2/366 time = 0.0006712379981763661
loading data/gamutrf-pdx/07_21_2022/wifi_5/gamutrf_recording_ettus_directional_gain70_1658464424_5240000000Hz_20971520sps.s16.zst_0.02.npy; 3/366 time = 0.0006639579660259187
loading data/gamutrf-pdx/07_21_2022/wifi_5/gamutrf_recording_ettus_directional_gain70_1658460723_5180000000Hz_20971520sps.s16.zst_0.02.npy; 4/366 time = 0.0006551200058311224
loading data/gamutrf-pdx/07_21_2022/wifi_5/gamutrf_recording_ettus_directional_gain70_1658456490_5720000000Hz_20971520sps.s16

 76%|█████████████████████████████████████████████████████████████████████████████████████▊                           | 278/366 [00:00<00:00, 682.55it/s]

loading data/gamutrf-pdx/07_21_2022/wifi_5/gamutrf_recording_ettus_directional_gain70_1658464605_5720000000Hz_20971520sps.s16.zst_0.02.npy; 149/366 time = 0.0011942579876631498
loading data/gamutrf-pdx/07_21_2022/wifi_5/gamutrf_recording_ettus_directional_gain70_1658458984_5720000000Hz_20971520sps.s16.zst_0.02.npy; 150/366 time = 0.0011505759903229773
loading data/gamutrf-pdx/07_21_2022/wifi_5/gamutrf_recording_ettus_directional_gain70_1658462259_5800000000Hz_20971520sps.s16.zst_0.02.npy; 151/366 time = 0.001882120966911316
loading data/gamutrf-pdx/07_21_2022/wifi_5/gamutrf_recording_ettus_directional_gain70_1658460435_5760000000Hz_20971520sps.s16.zst_0.02.npy; 152/366 time = 0.0016853170236572623
loading data/gamutrf-pdx/07_21_2022/wifi_5/gamutrf_recording_ettus_directional_gain70_1658463986_5740000000Hz_20971520sps.s16.zst_0.02.npy; 153/366 time = 0.0017847910057753325
loading data/gamutrf-pdx/07_21_2022/wifi_5/gamutrf_recording_ettus_directional_gain70_1658463690_5000000000Hz_209715

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 366/366 [00:00<00:00, 716.83it/s]


loading data/gamutrf-pdx/07_21_2022/wifi_5/gamutrf_recording_ettus_directional_gain70_1658458684_5180000000Hz_20971520sps.s16.zst_0.02.npy; 335/366 time = 0.0020412610028870404
loading data/gamutrf-pdx/07_21_2022/wifi_5/gamutrf_recording_ettus_directional_gain70_1658464545_5760000000Hz_20971520sps.s16.zst_0.02.npy; 336/366 time = 0.002045374014414847
loading data/gamutrf-pdx/07_21_2022/wifi_5/gamutrf_recording_ettus_directional_gain70_1658462650_5240000000Hz_20971520sps.s16.zst_0.02.npy; 337/366 time = 0.001769432972650975
loading data/gamutrf-pdx/07_21_2022/wifi_5/gamutrf_recording_ettus_directional_gain70_1658459344_5000000000Hz_20971520sps.s16.zst_0.02.npy; 338/366 time = 0.0017569800256751478
loading data/gamutrf-pdx/07_21_2022/wifi_5/gamutrf_recording_ettus_directional_gain70_1658455865_5160000000Hz_20971520sps.s16.zst_0.02.npy; 339/366 time = 0.0015253899618983269
loading data/gamutrf-pdx/07_21_2022/wifi_5/gamutrf_recording_ettus_directional_gain70_1658457416_5000000000Hz_2097152

In [4]:
# Build one single-class (drone) dataset per field day, so each recording
# session can be inspected on its own and recombined with ConcatDataset below.
# Each site gets its own dict name. Rebinding `label_dirs` here would silently
# replace the three-class configuration defined earlier, so re-running the
# full-dataset cell afterwards would build a drone-only, single-site dataset.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256)),
])  # same preprocessing as the full dataset; redefined so this cell stands alone

leesburg_label_dirs = {
    'drone': ['data/gamutrf-birdseye-field-days/leesburg_field_day_2022_06_15/worker1/'],
}
pdx_label_dirs = {
    'drone': ['data/gamutrf-birdseye-field-days/pdx_field_day_2022_05_26/worker1/gamutrf/'],
}

dataset1 = GamutRFDataset(leesburg_label_dirs, sample_secs=0.02, nfft=512, transform=transform)
dataset2 = GamutRFDataset(pdx_label_dirs, sample_secs=0.02, nfft=512, transform=transform)

label='drone', 57 files


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 57/57 [00:00<00:00, 443.84it/s]


loading data/gamutrf-birdseye-field-days/leesburg_field_day_2022_06_15/worker1/gamutrf_recording_ettus_directional-split_gain70_1655312998_5735000000Hz_20971520sps.s16.zst_0.02.npy; 0/57 time = 0.0024968899670057
loading data/gamutrf-birdseye-field-days/leesburg_field_day_2022_06_15/worker1/gamutrf_recording_ettus_directional-split_gain70_1655311950_5735000000Hz_20971520sps.s16.zst_0.02.npy; 1/57 time = 0.0016273150104098022
loading data/gamutrf-birdseye-field-days/leesburg_field_day_2022_06_15/worker1/gamutrf_recording_ettus_directional-split_gain70_1655312216_5735000000Hz_20971520sps.s16.zst_0.02.npy; 2/57 time = 0.00189797900384292
loading data/gamutrf-birdseye-field-days/leesburg_field_day_2022_06_15/worker1/gamutrf_recording_ettus_directional-split_gain70_1655312014_5735000000Hz_20971520sps.s16.zst_0.02.npy; 3/57 time = 0.0016497069736942649
loading data/gamutrf-birdseye-field-days/leesburg_field_day_2022_06_15/worker1/gamutrf_recording_ettus_directional-split_gain70_1655313272_57

  0%|                                                                                                                             | 0/74 [00:00<?, ?it/s]

loading data/gamutrf-birdseye-field-days/pdx_field_day_2022_05_26/worker1/gamutrf/gamutrf_recording_ettus_directional-split_gain45_1653586599_5735000000Hz_20971520sps.s16.zst_0.02.npy; 0/74 time = 0.0018690829747356474
loading data/gamutrf-birdseye-field-days/pdx_field_day_2022_05_26/worker1/gamutrf/gamutrf_recording_ettus_directional-split_gain45_1653591297_5735000000Hz_20971520sps.s16.zst_0.02.npy; 1/74 time = 0.0014649020158685744
loading data/gamutrf-birdseye-field-days/pdx_field_day_2022_05_26/worker1/gamutrf/gamutrf_recording_ettus_directional-split_gain45_1653590704_5735000000Hz_20971520sps.s16.zst_0.02.npy; 2/74 time = 0.0013422199990600348
loading data/gamutrf-birdseye-field-days/pdx_field_day_2022_05_26/worker1/gamutrf/gamutrf_recording_ettus_directional-split_gain70_1653665623_5735000000Hz_20971520sps.s16.zst_0.02.npy; 3/74 time = 0.0006835690001025796
loading data/gamutrf-birdseye-field-days/pdx_field_day_2022_05_26/worker1/gamutrf/gamutrf_recording_ettus_directional-split_

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 74/74 [00:00<00:00, 387.99it/s]


loading data/gamutrf-birdseye-field-days/pdx_field_day_2022_05_26/worker1/gamutrf/gamutrf_recording_ettus_directional-split_gain45_1653588224_5735000000Hz_20971520sps.s16.zst_0.02.npy; 7/74 time = 0.002192755986470729
loading data/gamutrf-birdseye-field-days/pdx_field_day_2022_05_26/worker1/gamutrf/gamutrf_recording_ettus_directional-split_gain45_1653588990_5735000000Hz_20971520sps.s16.zst_0.02.npy; 8/74 time = 0.0015017720288597047
loading data/gamutrf-birdseye-field-days/pdx_field_day_2022_05_26/worker1/gamutrf/gamutrf_recording_ettus_directional-split_gain45_1653591365_5735000000Hz_20971520sps.s16.zst_0.02.npy; 9/74 time = 0.001305958954617381
loading data/gamutrf-birdseye-field-days/pdx_field_day_2022_05_26/worker1/gamutrf/gamutrf_recording_ettus_directional-split_gain45_1653588089_5735000000Hz_20971520sps.s16.zst_0.02.npy; 10/74 time = 0.0013029679539613426
loading data/gamutrf-birdseye-field-days/pdx_field_day_2022_05_26/worker1/gamutrf/gamutrf_recording_ettus_directional-split_g

In [None]:
# Concatenate the per-site drone datasets into one dataset:
# indices [0, len(dataset1)) map to dataset1, the remaining ones to dataset2.
dataset3 = torch.utils.data.ConcatDataset([dataset1, dataset2])

array([['drone',
        'data/gamutrf-birdseye-field-days/leesburg_field_day_2022_06_15/worker1/gamutrf_recording_ettus_directional-split_gain70_1655312998_5735000000Hz_20971520sps.s16.zst',
        '0'],
       ['drone',
        'data/gamutrf-birdseye-field-days/leesburg_field_day_2022_06_15/worker1/gamutrf_recording_ettus_directional-split_gain70_1655312998_5735000000Hz_20971520sps.s16.zst',
        '1677720'],
       ['drone',
        'data/gamutrf-birdseye-field-days/leesburg_field_day_2022_06_15/worker1/gamutrf_recording_ettus_directional-split_gain70_1655312998_5735000000Hz_20971520sps.s16.zst',
        '3355440'],
       ...,
       ['wifi_5',
        'data/gamutrf-pdx/07_21_2022/wifi_5/gamutrf_recording_ettus_directional_gain70_1658462560_5160000000Hz_20971520sps.s16.zst',
        '1001598840'],
       ['wifi_5',
        'data/gamutrf-pdx/07_21_2022/wifi_5/gamutrf_recording_ettus_directional_gain70_1658462560_5160000000Hz_20971520sps.s16.zst',
        '1003276560'],
       ['w

In [42]:
# Count samples per recording source to sanity-check the dataset composition.
# Entries of dataset.idx are indexed below as entry[0] = label, entry[1] = file path.
a = [i for i, entry in enumerate(dataset.idx) if 'leesburg' in entry[1]]  # Leesburg field day
b = [i for i, entry in enumerate(dataset.idx) if 'pdx' in entry[1] and 'field' in entry[1]]  # PDX field day
c = [i for i, entry in enumerate(dataset.idx) if entry[0] == 'drone']  # every drone sample
d = [
    i for i, entry in enumerate(dataset.idx)
    if 'leesburg' not in entry[1] or 'field' not in entry[1]
]  # everything except the Leesburg field day
print(
    len(a),
    len(b),
    len(a) + len(b),
    len(c),
    len(d),
    len(dataset),
    len(dataset) - len(a),
    sep='\n',
)

34097
110250
144347
144347
487597
521694
487597


In [59]:
# Split strategy: keep the entire Leesburg field day out of the random split and
# append it to the validation set, so validation also covers a recording
# session the model never trained on.
train_val_test_split = [0.77, 0.03, 0.20]
split_seed = 0  # fixed seed so the split is reproducible across kernel restarts

# Classify each sample once and derive both index lists from the same flag, so they
# are guaranteed to partition the dataset. Using two different predicates (as the
# previous version did) could put a sample in both training and validation.
is_leesburg = ['leesburg' in entry[1] for entry in dataset.idx]
all_except_leesburg = [i for i, flag in enumerate(is_leesburg) if not flag]
just_leesburg = [i for i, flag in enumerate(is_leesburg) if flag]

dataset_sub = torch.utils.data.Subset(dataset, all_except_leesburg)

# Round train/val up and give test the remainder, so the three lengths always
# sum to len(dataset_sub), which random_split requires.
n_sub = len(dataset_sub)
n_train = int(np.ceil(train_val_test_split[0] * n_sub))
n_val = int(np.ceil(train_val_test_split[1] * n_sub))
n_test = n_sub - n_train - n_val
assert n_test >= 0, "train/val fractions exceed 1"
train_dataset, validation_dataset, test_dataset = torch.utils.data.random_split(
    dataset_sub,
    (n_train, n_val, n_test),
    generator=torch.Generator().manual_seed(split_seed),
)

leesburg_subset = torch.utils.data.Subset(dataset, just_leesburg)
validation_dataset = torch.utils.data.ConcatDataset((validation_dataset, leesburg_subset))

In [60]:
# Report split sizes. The three splits together cover the whole dataset,
# because the held-out Leesburg samples were folded into the validation set.
print(
    f"{len(train_dataset)=}",
    f"{len(validation_dataset)=}",
    f"{len(test_dataset)=}",
    f"total len = {sum(len(split) for split in (train_dataset, validation_dataset, test_dataset))}",
    f"{len(all_except_leesburg)=}",
    f"{len(dataset)=}",
    sep="\n",
)

len(train_dataset)=375450
len(validation_dataset)=48725
len(test_dataset)=97519
total len = 521694
len(all_except_leesburg)=487597
len(dataset)=521694


In [61]:
# Derived from the live split instead of the literal 375450 copied from an
# earlier output, which goes stale whenever the split changes.
# NOTE(review): the intent of dividing by 8 is not shown here (batch size?
# device count?) — TODO document.
len(train_dataset) / 8

46931.25

48725

In [None]:
# Sanity-check the per-site datasets and their concatenation.
print(f"{len(dataset1)=}")
print(f"{len(dataset2)=}")
print(f"{len(dataset1)+len(dataset2)=}")
print(f"{len(dataset3)=}")
print(dataset1.class_to_idx)
# ConcatDataset has no class_to_idx attribute (accessing it raises AttributeError);
# inspect the mapping of each underlying dataset instead.
print([ds.class_to_idx for ds in dataset3.datasets])
print(dataset1[0][0])
# dataset2's first sample sits at offset len(dataset1) in the concatenation; the
# previous hard-coded 34097 only matched one particular run.
print(dataset3[len(dataset1)][0])
print(dataset2[0][0])

In [None]:
# Alternative split: random train/val/test over the full dataset (no Leesburg
# hold-out), using `train_val_test_split` from an earlier cell. This overwrites
# the splits built by the hold-out cell.
n_total = len(dataset)
n_train = int(np.ceil(train_val_test_split[0] * n_total))
n_val = int(np.ceil(train_val_test_split[1] * n_total))
n_test = n_total - n_train - n_val  # remainder, so lengths always sum to len(dataset)
assert n_test >= 0, "train/val fractions exceed 1"
train_dataset, validation_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    (n_train, n_val, n_test),
    generator=torch.Generator().manual_seed(0),  # reproducible split
)

In [None]:
# Inspect one sample through the dataset's debug helper.
# NOTE(review): assumes 350 is a sample index — confirm against GamutRFDataset.debug.
dataset.debug(350)

In [None]:
# Random 75/5/20 split over the full dataset (no Leesburg hold-out). This
# rebinds train_val_test_split and overwrites any earlier splits.
train_val_test_split = [0.75, 0.05, 0.20]
n_total = len(dataset)
n_train = int(np.ceil(train_val_test_split[0] * n_total))
n_val = int(np.ceil(train_val_test_split[1] * n_total))
n_test = n_total - n_train - n_val  # remainder, so lengths always sum to len(dataset)
assert n_test >= 0, "train/val fractions exceed 1"
train_dataset, validation_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    (n_train, n_val, n_test),
    generator=torch.Generator().manual_seed(0),  # reproducible split
)
print(f"{len(train_dataset)=}")
print(f"{len(validation_dataset)=}")

In [None]:

# Shuffled mini-batch loader: 32 samples per batch, 14 worker processes.
# NOTE(review): this wraps the full `dataset`, not `train_dataset` — confirm
# that is intended if this loader is used for training.
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
    num_workers=14,
)


In [None]:
# Start from torchvision's default pretrained ResNet-18 weights and replace the
# final fully connected layer with one that has one output per dataset class.
model_ft = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
num_ftrs = model_ft.fc.in_features
print(num_ftrs)
model_ft.fc = torch.nn.Linear(num_ftrs, len(dataset.class_to_idx))

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Evaluate the saved checkpoint on a small random sample and save a
# row-normalised confusion matrix.
# NOTE(review): this re-splits the *full* dataset at random, so the evaluated
# samples may include ones the checkpoint was trained on; use a held-out split
# if an unbiased estimate is needed.
n_total = len(dataset)
n_eval = int(0.00005 * n_total)
train_dataset, validation_dataset = torch.utils.data.random_split(
    dataset,
    (n_total - n_eval, n_eval),  # remainder form: lengths always sum to len(dataset)
    generator=torch.Generator().manual_seed(0),  # reproducible sample
)

validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=1, num_workers=10)

model = model_ft.to(device)
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
# torch.load unpickles the file, so only load trusted checkpoints.
model.load_state_dict(torch.load('resnet18_0.02_1_current.pt', map_location=device))
model.eval()

predictions = []
labels = []
with torch.no_grad():  # inference only: don't build autograd graphs
    for i, (data, label) in enumerate(validation_dataloader):
        print(f"testing {i}/{len(validation_dataloader)}")
        out = model(data.to(device))
        _, preds = torch.max(out, 1)
        # extend with the whole batch so this keeps working if batch_size > 1
        predictions.extend(preds.cpu().tolist())
        labels.extend(label.tolist())

# Pin the label set and order the names by class index. Otherwise a class that
# is absent from this small sample shrinks the matrix so it no longer matches
# display_labels, and class_to_idx key order need not equal index order.
class_indices = sorted(dataset.class_to_idx.values())
disp = ConfusionMatrixDisplay.from_predictions(
    labels,
    predictions,
    labels=class_indices,
    display_labels=[dataset.idx_to_class[k] for k in class_indices],
    normalize='true',
)
disp.figure_.savefig('confusion_matrix123123.png')

In [None]:
# Visual spot-check: run the checkpoint on a few randomly drawn samples and plot
# each input with its true label.
# The loader spans the whole dataset (~500k samples), so cap the number of figures.
max_samples_to_show = 10

model = model_ft.to(device)
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
model.load_state_dict(torch.load('resnet18_0.02_1_current.pt', map_location=device))
model.eval()

eval_dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)
print(dataset.idx_to_class)
with torch.no_grad():  # inference only: don't build autograd graphs
    for shown, (data, label) in enumerate(eval_dataloader):
        if shown >= max_samples_to_show:
            break
        data = data.to(device)
        label = label.to(device)
        out = model(data)

        _, preds = torch.max(out, 1)
        correct = preds == label
        print(f"out={out}")
        print(f"label={label.item()}, prediction={preds.item()}")
        print(f"correct={correct.item()}")

        fig, ax = plt.subplots()
        image = ax.imshow(
            np.moveaxis(data.squeeze().cpu().numpy(), 0, -1),
            aspect='auto', origin='lower', cmap=plt.get_cmap('jet'),
        )
        fig.colorbar(image, ax=ax)
        ax.set_title(f"{dataset.idx_to_class[label.item()]}")
        plt.show()
        plt.close(fig)  # release the figure; many figures in a loop otherwise accumulate