In [1]:
import os
import numpy as np 
import zstandard
from tqdm import tqdm
import torch
from scipy import signal
import matplotlib
import matplotlib.pyplot as plt
from timeit import default_timer as timer
import torchvision
from torchvision import datasets, models, transforms
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from gamutRF.gamutrf_dataset import * 

from gamutrf.sample_reader import get_reader
from gamutrf.utils import parse_filename 

In [2]:
label_dirs= {
    'drone': ['data/gamutrf-birdseye-field-days/leesburg_field_day_2022_06_15/worker1/','data/gamutrf-birdseye-field-days/pdx_field_day_2022_05_26/worker1/gamutrf/'], 
    'wifi_2_4': ['data/gamutrf-pdx/07_21_2022/wifi_2_4/'], 
    'wifi_5': ['data/gamutrf-pdx/07_21_2022/wifi_5/']
}
sample_secs = 0.02
nfft = 512
batch_size = 8
num_workers = 19


dataset = GamutRFDataset(label_dirs, sample_secs=sample_secs, nfft=nfft)

train_val_test_split = [0.77, 0.03, 0.20]
all_except_leesburg = [i for (i, idx) in enumerate(dataset.idx) if not('leesburg' in idx[1] and 'field' in idx[1])] 
dataset_sub = torch.utils.data.Subset(dataset, all_except_leesburg)
train_dataset, validation_dataset, test_dataset = torch.utils.data.random_split(dataset_sub, (int(np.ceil(train_val_test_split[0]*len(dataset_sub))), int(np.ceil(train_val_test_split[1]*len(dataset_sub))), int(train_val_test_split[2]*len(dataset_sub))))
just_leesburg = [i for (i, idx) in enumerate(dataset.idx) if 'leesburg' in idx[1]]
leesburg_subset = torch.utils.data.Subset(dataset, just_leesburg)
validation_dataset = torch.utils.data.ConcatDataset((validation_dataset,leesburg_subset))

label='drone', 131 files
label='wifi_2_4', 263 files
label='wifi_5', 366 files


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 131/131 [00:00<00:00, 475.83it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 263/263 [00:00<00:00, 643.49it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 366/366 [00:00<00:00, 919.54it/s]


In [3]:
model = models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
model.fc = torch.nn.Linear(model.fc.in_features, len(dataset.class_to_idx))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.load_state_dict(torch.load('resnet18_leesburg_split_0.02_1_current.pt'))
model.eval()


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [7]:
leesburg_dataloader = torch.utils.data.DataLoader(leesburg_subset, batch_size=batch_size, num_workers=num_workers)
print(dataset.idx_to_class)
i = 0 
total = 0
total_correct = 0 
for data,label,idx in leesburg_dataloader: 
    
    data = data.to(device)
    label = label.to(device)
    out = model(data)

    _, preds = torch.max(out, 1)
    correct = torch.sum(preds == label.data)
    total_correct += correct
    total += len(preds)
    print(f"avg correct = {total_correct/total}")
#     print(f"out={out.cpu().detach().numpy()}")
#     print(f"label={dataset.idx_to_class[label.item()]}, prediction={dataset.idx_to_class[preds.item()]}")
#     print(f"correct={correct.item()}")
#     print(f"{idx}")
    
#     plt.imshow(np.moveaxis(data.squeeze().cpu().numpy(), 0, -1), aspect='auto', origin='lower', cmap=plt.get_cmap('jet'))
#     plt.colorbar()
#     plt.title(f"{dataset.idx_to_class[label.item()]}")
#     plt.show()
#     i += 1 
#     if i > 100: 
#         break

{0: 'drone', 1: 'wifi_2_4', 2: 'wifi_5'}
avg correct = 0.125
avg correct = 0.25
avg correct = 0.1666666716337204
avg correct = 0.15625
avg correct = 0.125
avg correct = 0.1041666716337204
avg correct = 0.0892857164144516
avg correct = 0.078125
avg correct = 0.0694444477558136
avg correct = 0.0625
avg correct = 0.07954545319080353
avg correct = 0.0729166716337204
avg correct = 0.08653846383094788
avg correct = 0.0803571492433548
avg correct = 0.0833333358168602
avg correct = 0.09375
avg correct = 0.10294117778539658
avg correct = 0.118055559694767
avg correct = 0.125
avg correct = 0.125
avg correct = 0.125
avg correct = 0.11931818723678589
avg correct = 0.1195652186870575
avg correct = 0.1302083432674408
avg correct = 0.13499999046325684
avg correct = 0.13461539149284363
avg correct = 0.12962962687015533
avg correct = 0.1294642984867096
avg correct = 0.125
avg correct = 0.12083333730697632
avg correct = 0.11693547666072845
avg correct = 0.11328125
avg correct = 0.10984848439693451
avg c

avg correct = 0.6450381875038147
avg correct = 0.6463878154754639
avg correct = 0.6472538113594055
avg correct = 0.6485849022865295
avg correct = 0.6499060392379761
avg correct = 0.6512172222137451
avg correct = 0.652518630027771
avg correct = 0.6538103818893433
avg correct = 0.6550925970077515
avg correct = 0.6563653349876404
avg correct = 0.6576286554336548
avg correct = 0.6588827967643738
avg correct = 0.6601277589797974
avg correct = 0.6613636612892151
avg correct = 0.6625905632972717
avg correct = 0.6638086438179016
avg correct = 0.6650180220603943
avg correct = 0.666218638420105
avg correct = 0.6674107313156128
avg correct = 0.6685943007469177
avg correct = 0.6697694659233093
avg correct = 0.6709364056587219
avg correct = 0.6720950603485107
avg correct = 0.6732456088066101
avg correct = 0.6743881106376648
avg correct = 0.6755226254463196
avg correct = 0.6766493320465088
avg correct = 0.6777681708335876
avg correct = 0.6788793206214905
avg correct = 0.6799828410148621
avg correct 

avg correct = 0.7695841789245605
avg correct = 0.7700289487838745
avg correct = 0.7702312469482422
avg correct = 0.7701923251152039
avg correct = 0.770393431186676
avg correct = 0.7708333134651184
avg correct = 0.7710325121879578
avg correct = 0.7714694738388062
avg correct = 0.7714285850524902
avg correct = 0.7718631029129028
avg correct = 0.7722960114479065
avg correct = 0.7727273106575012
avg correct = 0.7731569409370422
avg correct = 0.7735849022865295
avg correct = 0.7737758755683899
avg correct = 0.7742011547088623
avg correct = 0.7743902206420898
avg correct = 0.7745786309242249
avg correct = 0.7750000357627869
avg correct = 0.7754197716712952
avg correct = 0.7758380174636841
avg correct = 0.7762546539306641
avg correct = 0.7766697406768799
avg correct = 0.7768518328666687
avg correct = 0.7772643566131592
avg correct = 0.7776752710342407
avg correct = 0.7778545022010803
avg correct = 0.7782628536224365
avg correct = 0.7786697149276733
avg correct = 0.7790750861167908
avg correct

avg correct = 0.752607524394989
avg correct = 0.7517904043197632
avg correct = 0.7516254782676697
avg correct = 0.7506493330001831
avg correct = 0.75
avg correct = 0.7496761679649353
avg correct = 0.7496765851974487
avg correct = 0.75
avg correct = 0.7498387098312378
avg correct = 0.7501610517501831
avg correct = 0.75
avg correct = 0.7498393058776855
avg correct = 0.7498395442962646
avg correct = 0.7496795058250427
avg correct = 0.75
avg correct = 0.7496802806854248
avg correct = 0.7498403191566467
avg correct = 0.75
avg correct = 0.75
avg correct = 0.7501590251922607
avg correct = 0.7504764795303345
avg correct = 0.7507931590080261
avg correct = 0.7511090040206909
avg correct = 0.7514240741729736
avg correct = 0.7517383098602295
avg correct = 0.751893937587738
avg correct = 0.7520492076873779
avg correct = 0.7522040009498596
avg correct = 0.7525157332420349
avg correct = 0.7528266310691833
avg correct = 0.7531367540359497
avg correct = 0.753446102142334
avg correct = 0.753754734992981

avg correct = 0.7578960061073303
avg correct = 0.7571601867675781
avg correct = 0.7565470337867737
avg correct = 0.7560561895370483
avg correct = 0.7554453015327454
avg correct = 0.7549564838409424
avg correct = 0.7543478012084961
avg correct = 0.7537403106689453
avg correct = 0.7530134916305542
avg correct = 0.7522880434989929
avg correct = 0.7515640258789062
avg correct = 0.7508413195610046
avg correct = 0.7503601908683777
avg correct = 0.750239908695221
avg correct = 0.7497603297233582
avg correct = 0.749401330947876
avg correct = 0.7491626739501953
avg correct = 0.7490440011024475
avg correct = 0.7483285665512085
avg correct = 0.7478530406951904
avg correct = 0.7473784685134888
avg correct = 0.7467857003211975
avg correct = 0.746194064617157
avg correct = 0.745603621006012
avg correct = 0.7453703284263611
avg correct = 0.7452561855316162
avg correct = 0.7450236678123474
avg correct = 0.7450284361839294
avg correct = 0.7451514005661011
avg correct = 0.7453922629356384
avg correct = 

avg correct = 0.7776274681091309
avg correct = 0.7774131298065186
avg correct = 0.7772955298423767
avg correct = 0.776792585849762
avg correct = 0.7764830589294434
avg correct = 0.7761739492416382
avg correct = 0.7762500047683716
avg correct = 0.7760376334190369
avg correct = 0.7755376100540161
avg correct = 0.7749423980712891
avg correct = 0.7743481397628784
avg correct = 0.7738506197929382
avg correct = 0.7734494805335999
avg correct = 0.7728577256202698
avg correct = 0.7723624110221863
avg correct = 0.7717723846435547
avg correct = 0.7714694738388062
avg correct = 0.7709763646125793
avg correct = 0.7705792188644409
avg correct = 0.7701827883720398
avg correct = 0.7700722813606262
avg correct = 0.7700570225715637
avg correct = 0.7697568535804749
avg correct = 0.7693621516227722
avg correct = 0.7687784433364868
avg correct = 0.7684799432754517
avg correct = 0.7678977251052856
avg correct = 0.7673164010047913
avg correct = 0.7667359709739685
avg correct = 0.7662509679794312
avg correct

avg correct = 0.7425144910812378
avg correct = 0.7426801919937134
avg correct = 0.7428456544876099
avg correct = 0.743010938167572
avg correct = 0.7431759834289551
avg correct = 0.7433408498764038
avg correct = 0.7435054779052734
avg correct = 0.7435897588729858
avg correct = 0.7437540292739868
avg correct = 0.7438380122184753
avg correct = 0.7440019249916077
avg correct = 0.7441655993461609
avg correct = 0.744329035282135
avg correct = 0.7444922924041748
avg correct = 0.7446553707122803
avg correct = 0.7448182106018066
avg correct = 0.7449808716773987
avg correct = 0.7451432943344116
avg correct = 0.7453054785728455
avg correct = 0.7454675436019897
avg correct = 0.7456293702125549
avg correct = 0.745790958404541
avg correct = 0.7459523677825928
avg correct = 0.7457963228225708
avg correct = 0.7455611824989319
avg correct = 0.7453263401985168
avg correct = 0.7450917959213257
avg correct = 0.7448576092720032
avg correct = 0.7444654703140259
avg correct = 0.7441530227661133
avg correct =

avg correct = 0.7170835733413696
avg correct = 0.7172390222549438
avg correct = 0.7173942923545837
avg correct = 0.7175493836402893
avg correct = 0.7176358103752136
avg correct = 0.7177906036376953
avg correct = 0.7179452180862427
avg correct = 0.7180312275886536
avg correct = 0.718185544013977
avg correct = 0.718339741230011
avg correct = 0.7184937000274658
avg correct = 0.7186475396156311
avg correct = 0.7188011407852173
avg correct = 0.7189546823501587
avg correct = 0.7190398573875427
avg correct = 0.7190566658973694
avg correct = 0.7190054655075073
avg correct = 0.7189542055130005
avg correct = 0.7191072106361389
avg correct = 0.7191240191459656
avg correct = 0.719140887260437
avg correct = 0.719293475151062
avg correct = 0.7194459438323975
avg correct = 0.7194625735282898
avg correct = 0.7196147441864014
avg correct = 0.7196990251541138
avg correct = 0.7197831869125366
avg correct = 0.7199349999427795
avg correct = 0.7200866341590881
avg correct = 0.7202380895614624
avg correct = 

KeyboardInterrupt: 

In [None]:


eval_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=1, shuffle=True, num_workers=num_workers)
print(dataset.idx_to_class)
i = 0 
for data,label,idx in eval_dataloader: 
    
    data = data.to(device)
    label = label.to(device)
    out = model(data)

    _, preds = torch.max(out, 1)
    correct = preds == label.data
    print(f"out={out.cpu().detach().numpy()}")
    print(f"label={dataset.idx_to_class[label.item()]}, prediction={dataset.idx_to_class[preds.item()]}")
    print(f"correct={correct.item()}")
    print(f"{idx}")
    
    plt.imshow(np.moveaxis(data.squeeze().cpu().numpy(), 0, -1), aspect='auto', origin='lower', cmap=plt.get_cmap('jet'))
    plt.colorbar()
    plt.title(f"{dataset.idx_to_class[label.item()]}")
    plt.show()
    i += 1 
    if i > 100: 
        break