<a href="https://colab.research.google.com/github/n-west/Wideband-RF-Signal-Detection-with-Machine-Learning/blob/main/3_Spectral_Segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os, sys
import numpy as np
import matplotlib.pyplot as plt
import pylab as pl
from IPython import display

import torch.optim as optim
import torch.nn as nn
import torch

import pandas as pd

import copy
import json
import numpy as np
import datetime
from tqdm import tqdm
from mpl_toolkits.axes_grid1 import make_axes_locatable


In [None]:
# Show the NVIDIA GPU(s) visible to this runtime, to confirm an accelerator is attached before training.
!nvidia-smi

The following cell downloads a small simulation dataset from Google Drive into your VM and extracts it. This takes roughly 10-15 minutes, so run it once, early in the tutorial.

In [None]:
def the_download():
  !gdown --id 1IYn3-vmqfMThLa_njyZGLi4YNuNnVZrb
  !ls /content/grcon_wideband_train_small.zip
  !mkdir /data
  !unzip /content/grcon_wideband_train_small.zip -d /data/
  !ls /data

# Download at most once per kernel session.
# Only the flag lookup is guarded: a failure inside the_download() (or a
# keyboard interrupt) must surface instead of silently triggering a second
# download attempt.
try:
  DATASET_DOWNLOADED
except NameError:
  # First execution in this kernel: the flag does not exist yet.
  DATASET_DOWNLOADED = False

if DATASET_DOWNLOADED:
  print("dataset is already downloaded. skipping")
else:
  the_download()
  DATASET_DOWNLOADED = True

In [None]:
def segnet_labels(anno, nfft, steps, class_map, nClass, single_class=False):
    """Rasterize bounding-box annotations into per-pixel one-hot labels.

    Parameters
    ----------
    anno : iterable of dict
        Each record holds pixel bounds ``'x1'``, ``'x2'`` (frequency-bin axis)
        and ``'y1'``, ``'y2'`` (time-step axis), both ends inclusive, plus a
        ``'class'`` key that is looked up in ``class_map``. Bounds are clipped
        to the grid.
    nfft : int
        Grid width (number of frequency bins).
    steps : int
        Grid height (number of time steps).
    class_map : dict
        Maps a class name to an integer ``c``; that class is written to
        column ``c + 1`` of the result.
    nClass : int
        Number of class columns; the result has ``nClass + 1`` columns.
    single_class : bool
        When True, collapse all classes into one "any signal" column.
        Requires ``nClass >= 1``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(nfft * steps, nClass + 1)`` whose rows follow the
        flattened ``(step, bin)`` grid in row-major order. Column 0 is the
        background (1 where no class is set). With ``single_class=True`` the
        shape is ``(nfft * steps, 2)``: ``[background, any-signal]``.
    """
    Y = np.zeros((nfft * steps, nClass + 1))
    # 3-D view onto Y (rows are step-major), so boxes can be written directly
    # instead of allocating and flattening a scratch image per annotation.
    grid = Y.reshape(steps, nfft, nClass + 1)

    for a in anno:
        x1 = int(np.clip(a['x1'], 0, nfft - 1))
        x2 = int(np.clip(a['x2'], 0, nfft - 1))
        y1 = int(np.clip(a['y1'], 0, steps - 1))
        y2 = int(np.clip(a['y2'], 0, steps - 1))

        # Inclusive bounds; an inverted box (x2 < x1 or y2 < y1) is an empty slice.
        grid[y1:y2 + 1, x1:x2 + 1, class_map[a['class']] + 1] = 1

    # Background: every pixel that no annotation touched.
    background = Y.sum(axis=1) == 0
    Y[background, 0] = 1

    if single_class:
        # The foreground mask must come from the pre-background state; after
        # the line above every row sums to at least 1.
        Y[~background, 1] = 1
        return Y[:, 0:2]

    return Y

class spectrogram_generator:
    """Random batch generator of raw IQ windows with per-pixel segmentation labels.

    On construction, every ``*sigmf-meta`` file in ``file_list`` (or in
    ``basedir`` when no list is given) is parsed and the annotations whose
    description is a key of ``class_map`` are collected in ``self.annotations``.
    ``generate`` then draws windows of ``nfft * steps`` complex64 samples from
    the matching data files (file name obtained by replacing ``'meta'`` with
    ``'data'``) and rasterizes the overlapping annotations with
    ``segnet_labels``.

    Parameters
    ----------
    basedir : str
        Directory holding the metadata and data files.
    nfft, steps : int
        Label grid width (frequency bins) and height (time steps).
    class_map : dict
        Class name -> integer index; deep-copied. Annotations with other
        descriptions are counted but not used.
    batch_size : int
        Default number of windows per batch.
    file_list : list of str, optional
        Metadata file names relative to ``basedir``.
    single_class : bool
        When True every annotation is relabelled ``'detection'`` before the
        ``class_map`` lookup.
    augment, norm, network, empty_anno, use_pysinc
        Stored on the instance; no method of this class reads them.
    """

    # Column layout of the ``annotations`` DataFrame.
    _ANNOTATION_COLUMNS = ['description','fc','fcb','fs','f_low','f_high','bw',
                           'start_sample','end_sample','samples','sec','file']

    def __init__(self, basedir, nfft, steps, class_map, batch_size = 16,
                 augment = False, norm = False, file_list = None, network = 'segnet',
                 empty_anno = True, use_pysinc = False, single_class = False):

        self.basedir = basedir
        self.nfft = nfft
        self.steps = steps
        self.class_map = copy.deepcopy(class_map)
        self.batch_size = batch_size
        self.augment = augment
        self.norm = norm
        self.network = network
        self.empty_anno = empty_anno
        self.use_pysinc = use_pysinc
        self.single_class = single_class
        self.aug = None

        self.file_list = file_list

        if self.file_list is None:
            self.file_list = [f for f in os.listdir(basedir) if f.endswith('sigmf-meta')]

        # Rows are collected in a plain list and turned into a DataFrame once:
        # DataFrame.append was removed in pandas 2.0 and copied the whole frame
        # on every call.
        rows = []

        for f in self.file_list:
            # Context manager so the metadata handle is closed promptly.
            with open(os.path.join(basedir, f), 'r') as meta_file:
                meta = json.load(meta_file)
            anno = meta['annotations']

            fc = meta['captures'][0]['core:frequency']
            fs = meta['global']['core:sample_rate']

            print('Loading Annotations from: ', f)

            # description -> [used count, not-used count]
            annocount = {}
            for a in anno:
                comment = a.get('core:description', "")
                if single_class:
                    comment = 'detection'
                counts = annocount.setdefault(comment, [0, 0])
                if comment not in self.class_map:  # Can control the labels of the generated sample
                    counts[1] += 1  # increment the NOT used count
                    continue

                counts[0] += 1  # increment the USED count

                f_low     = a['core:freq_lower_edge']
                f_high    = a['core:freq_upper_edge']
                startsamp = a['core:sample_start']
                nsamp     = a['core:sample_count']
                bw        = np.abs(f_high - f_low)
                sec       = nsamp / float(fs)
                fcb       = (f_low + f_high) / 2.0

                rows.append([comment, fc, fcb, fs, f_low, f_high, bw,
                             startsamp, startsamp + nsamp, nsamp, sec, f])

            for name, (used, unused) in annocount.items():
                print("%s :: Using %d annotation :: NOT Using %d annotations."%(name, used, unused))

        self.annotations = pd.DataFrame(rows, columns=self._ANNOTATION_COLUMNS)

        self.get_class_map()
        self.inv_map = {v: k for k, v in self.class_map.items()}
        self.nClass = len(self.class_map)
        print('All Annotations Loaded')

    def keras_gen_train(self, maxproc=4):
        """Endless generator of augmented ``(X, Y)`` batches. ``maxproc`` is unused."""
        while True:
            X, Y = self.generate(augment=True)
            yield X, Y

    def keras_gen_val(self, maxproc=4):
        """Endless generator of un-augmented ``(X, Y)`` batches. ``maxproc`` is unused."""
        while True:
            X, Y = self.generate(augment=False)
            yield X, Y

    def get_class_map(self):
        """Return the class-name -> index mapping exactly as supplied by the caller."""
        return self.class_map

    @staticmethod
    def _read_window(Fsamps, offset_start, offset_end, augment):
        """Return samples ``[offset_start, offset_end)`` scaled by 1/65536, plus optional noise."""
        samps = Fsamps[offset_start:offset_end] / np.float32(65536.0)
        if augment:
            n = offset_end - offset_start
            # Complex Gaussian noise with a random amplitude in [1e-9, 1e-7).
            samps += np.complex64(np.random.uniform(1e-9, 1e-7) * (np.random.randn(n) + 1.j*np.random.randn(n)))
        return samps

    def _annotation_boxes(self, rel_anno, offset_start, offset_end, fc, fs):
        """Convert annotation rows to pixel boxes relative to the current window."""
        band_low = fc - fs/2.0
        records = []
        for _, r in rel_anno.iterrows():
            # Time extent, clipped to the window and expressed in rows of nfft samples.
            t_first = (max(offset_start, r['start_sample']) - offset_start) / self.nfft
            t_last = (min(offset_end, r['end_sample']) - offset_start) / self.nfft

            # Frequency extent as a fraction of the captured band, clipped to [0, 1].
            frac_low = max(0, (r['f_low'] - band_low)/fs)
            frac_high = min(1, (r['f_high'] - band_low)/fs)

            assert(frac_low <= 1)
            assert(frac_low >= 0)
            assert(frac_high <= 1)
            assert(frac_high >= 0)

            records.append({'x1': int(frac_low * self.nfft), 'x2': int(frac_high * self.nfft),
                            'y1': int(t_first), 'y2': int(t_last),
                            'class': r['description']})
        return records

    def generate(self, batch_size=None,deterministic_file=None, deterministic_offset=None,returnearly=False, augment=False):
        """Build one batch of IQ windows and their label maps.

        Parameters
        ----------
        batch_size : int, optional
            When given it replaces ``self.batch_size`` (and stays in effect for
            later calls).
        deterministic_file : str, optional
            Metadata file name to use for every window instead of a random one.
        deterministic_offset : int, optional
            First sample of every window instead of a random offset.
        returnearly : bool
            Unused.
        augment : bool
            Add low-level complex Gaussian noise to the samples.

        Returns
        -------
        (batch_X, batch_Y)
            ``batch_X``: float32, shape ``(batch, steps * nfft, 2)`` holding
            the real and imaginary parts. ``batch_Y``: float32, shape
            ``(batch, nClass + 1, steps, nfft)``; channel 0 is the background.
        """
        if batch_size is not None:
            self.batch_size=batch_size

        nsamp = self.nfft*self.steps
        batch_X = np.empty((self.batch_size, nsamp, 2), dtype=np.float32)
        batch_Y = np.empty((self.batch_size, self.nClass+1, self.steps, self.nfft), dtype=np.float32)

        for count in range(self.batch_size):
            if deterministic_file is None:
                f = np.random.choice(self.file_list)
            else:
                f = deterministic_file
            data_file = f.replace('meta','data')
            Fsamps = np.memmap(os.path.join(self.basedir,data_file),dtype=np.complex64,mode='r')
            anno = self.annotations.loc[self.annotations['file']==f]

            if anno.empty:
                # special case to allow empty captures (reduces false positives):
                # any window of the file may be drawn
                file_min = 0
                file_max = len(Fsamps)
            else:
                file_min = anno['start_sample'].min()
                file_max = anno['end_sample'].max()

            if deterministic_offset is None:
                offset_start = np.random.randint(max(0,file_min), min(len(Fsamps)-nsamp, file_max))
            else:
                offset_start = deterministic_offset
            offset_end = offset_start + nsamp

            # Annotated and empty captures share the same scaling and
            # augmentation so the network sees both at the same amplitude.
            samps = self._read_window(Fsamps, offset_start, offset_end, augment)
            batch_X[count,:,0] = samps.real
            batch_X[count,:,1] = samps.imag

            if anno.empty:
                # Pure background: channel 0 is 1 everywhere, all other channels 0.
                batch_Y[count] = 0.0
                batch_Y[count, 0] = 1.0
                continue

            fs = anno['fs'].iloc[0]
            fc = anno['fc'].iloc[0]

            # Keep only annotations overlapping the window in time and the captured band.
            overlaps = ((anno['end_sample'] > offset_start)
                        & (anno['start_sample'] < offset_end)
                        & (anno['f_high'] > fc-fs/2.0)
                        & (anno['f_low'] < fc+fs/2.0))
            records = self._annotation_boxes(anno.loc[overlaps], offset_start, offset_end, fc, fs)

            # Annotations goes to Label Generator
            batch_y_data = segnet_labels(records,self.nfft,self.steps,self.class_map, self.nClass)
            # (steps*nfft, channels) -> (channels, steps, nfft)
            batch_y_data = np.transpose(batch_y_data, (1,0)).reshape((self.nClass+1, self.steps, self.nfft))
            batch_Y[count,:,:,:] = batch_y_data

        return batch_X, batch_Y



In [None]:
basedir = '/data/grcon_train/'

# Collect the distinct annotation descriptions present in the dataset, in
# first-seen order, so the class_map below can be checked against them.
file_list = [f for f in os.listdir(basedir) if f.endswith('sigmf-meta')]
labels = []
for fname in file_list:
  with open(os.path.join(basedir, fname), mode="r") as md_f:
    md = json.load(md_f)
  for anno in md["annotations"]:
    # .get mirrors spectrogram_generator, which also tolerates a missing description
    description = anno.get("core:description", "")
    if description not in labels:
      labels.append(description)
print(labels)

# Every modulation maps to the same index, so all of them are written to the
# same label channel (segnet_labels uses column class_map[name] + 1).
class_map = {
            "PSK2": 1,
            "PSK4": 1,
            "PSK8": 1,
            "QAM16": 1,
            "QAM64": 1,
            "QAM256": 1,
            "OOK": 1,
            "FSK2": 1,
            "FSK4": 1,
            "GMSK": 1,
            "OFDM": 1,
            "AM_SSB": 1,
            "AM_DSB": 1,
            "FM": 1,
            }

batch_size = 32
nfft = 512    # frequency bins per spectrogram row
steps = 128   # spectrogram rows per example

train_gen = spectrogram_generator(basedir, nfft=nfft, steps=steps,
                                  class_map=class_map, file_list=file_list,
                                  batch_size=batch_size, norm=True)

nClass = len(train_gen.class_map)+1 # +1 for the background
# nClass = 2

# NOTE: both generators draw from the same files through train_gen; they
# differ only in whether augmentation noise is added.
training_generator = train_gen.keras_gen_train()
validation_generator = train_gen.keras_gen_val()

In [None]:
class DoubleConv(nn.Module):
    """Two stacked ``Conv2d -> ReLU`` layers that keep the spatial size.

    Parameters
    ----------
    input_channels : int
        Channels of the incoming feature map.
    output_channels : int
        Channels produced by both convolutions.
    kernel_size : int or (int, int)
        Convolution kernel; odd sizes keep height and width unchanged.
    """

    def __init__(self, input_channels, output_channels, kernel_size=(3, 3)):
        super(DoubleConv, self).__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        # "Same" padding derived from the kernel ((1, 1) for the default 3x3).
        # Callers concatenate this block's output with its input, so the
        # spatial size must not shrink for other kernel sizes either.
        padding = tuple(k // 2 for k in kernel_size)
        self.op = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU(),
            nn.Conv2d(output_channels, output_channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU()
        )

    def forward(self, x):
        """Apply both convolution stages to ``x``."""
        return self.op(x)


class OmninetDown(nn.Module):
    """Encoder stage: 2x2 max-pool followed by a DoubleConv with a dense skip.

    ``forward`` returns the pooled tensor and the pooled tensor concatenated
    (along the channel axis) with its convolved version, so the second output
    carries ``deepsig_actual_output_size`` channels.
    """

    def __init__(self, input_channels, output_channels, kernel_size=(3, 3)):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=(2, 2))
        self.conv = DoubleConv(input_channels, output_channels, kernel_size=kernel_size)
        # pooled input channels + convolved channels, see forward()
        self.deepsig_actual_output_size = input_channels + output_channels

    def forward(self, x):
        """Return ``(pooled, cat(pooled, conv(pooled)))``."""
        pooled = self.pool(x)
        features = self.conv(pooled)
        return pooled, torch.cat((pooled, features), dim=1)


class OmninetUp(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, bridge merge, DoubleConv.

    The upsampled tensor is concatenated with ``bridge`` and convolved; the
    result is concatenated with the upsampled tensor again, giving
    ``deepsig_actual_output_channels`` output channels.
    """

    def __init__(self, input_channels, intermediate_channels, bridge_channels, output_channels, kernel_size=(3, 3)):
        super().__init__()
        self.upsample = nn.ConvTranspose2d(
            input_channels,
            intermediate_channels,
            kernel_size=(2, 2),
            stride=2,
            bias=False,
        )
        merged_channels = intermediate_channels + bridge_channels
        self.conv = DoubleConv(merged_channels, output_channels, kernel_size=kernel_size)
        # convolved channels + upsampled channels, see forward()
        self.deepsig_actual_output_channels = output_channels + intermediate_channels

    def forward(self, x, bridge):
        """Upsample ``x``, merge it with ``bridge`` and return the dense output."""
        up = self.upsample(x)
        refined = self.conv(torch.cat((up, bridge), dim=1))
        return torch.cat((refined, up), dim=1)

class UnetInput(nn.Module):
    """Network head: one conv layer followed by a two-conv branch.

    The output is the first conv's activation concatenated with the branch
    output along the channel axis, i.e. ``2 * output_channels`` channels.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()

        def conv_relu(channels_in):
            # 3x3 convolution with padding 1 keeps height and width unchanged
            return [
                nn.Conv2d(channels_in, output_channels, kernel_size=(3, 3), padding=(1, 1)),
                nn.ReLU(),
            ]

        self.firstconv = nn.Sequential(*conv_relu(input_channels))
        self.residual = nn.Sequential(*conv_relu(output_channels), *conv_relu(output_channels))
        self.deepsig_actual_output_channels = output_channels * 2

    def forward(self, x):
        """Return ``cat(firstconv(x), residual(firstconv(x)))`` on the channel axis."""
        stem = self.firstconv(x)
        branch = self.residual(stem)
        return torch.cat((stem, branch), dim=1)


class OmninetOutput(nn.Module):
    """Network tail: 2x upsampling, two 3x3 convs and a per-pixel softmax.

    The output has ``number_classes`` channels that sum to 1 at every pixel
    and twice the spatial size of the input.
    """

    def __init__(self, input_channels, number_classes):
        super().__init__()
        self._number_classes = number_classes
        hidden = 16
        layers = [
            nn.ConvTranspose2d(input_channels, hidden, kernel_size=(2, 2), stride=2, bias=False),
            nn.Conv2d(hidden, hidden, kernel_size=(3, 3), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(hidden, number_classes, kernel_size=(3, 3), padding=(1, 1)),
            nn.Softmax(dim=1),  # normalise over the class/channel axis
        ]
        self.op = nn.Sequential(*layers)

    def forward(self, x):
        """Return per-pixel class probabilities for feature map ``x``."""
        probabilities = self.op(x)
        return probabilities

class Omninet(nn.Module):
    """U-Net style segmentation network operating on raw IQ samples.

    ``forward`` first turns the IQ input into a normalised log-magnitude
    spectrogram (Hann window, FFT per row, DC moved to the centre), then runs
    it through four down stages, three up stages and the output tail.

    Parameters
    ----------
    input_channels : int
        Channels of the spectrogram fed to the head layer (``forward`` builds
        a single-channel spectrogram).
    number_classes : int
        Number of output channels / classes, background included.
    nfft : int
        FFT length, i.e. spectrogram width. Default 512.
    steps : int
        Number of FFT rows per example, i.e. spectrogram height. Default 128.

    Raises
    ------
    ValueError
        If ``nfft`` or ``steps`` is not a multiple of 16.
    """

    def __init__(self, input_channels, number_classes, nfft=512, steps=128):
        super(Omninet, self).__init__()
        # Four 2x2 poolings halve each axis four times; the skip connections
        # only line up when both axes survive that without rounding.
        if nfft % 16 != 0 or steps % 16 != 0:
            raise ValueError("nfft and steps must be multiples of 16, got nfft=%r steps=%r" % (nfft, steps))

        kernel_size = (3, 3)
        self.head_layer = UnetInput(input_channels, 16)
        self.down_stage1 = OmninetDown(self.head_layer.deepsig_actual_output_channels, 32, kernel_size=kernel_size)
        self.down_stage2 = OmninetDown(self.down_stage1.deepsig_actual_output_size, 64, kernel_size=kernel_size)
        self.down_stage3 = OmninetDown(self.down_stage2.deepsig_actual_output_size, 64, kernel_size=kernel_size)
        self.down_stage4 = OmninetDown(self.down_stage3.deepsig_actual_output_size, 64, kernel_size=kernel_size)

        # Each up stage is bridged with the *pooled input* of the down stage at
        # the same resolution, whose channel count is the previous stage's output.
        self.upstage_1 = OmninetUp(self.down_stage4.deepsig_actual_output_size, 64, self.down_stage2.deepsig_actual_output_size, 64, kernel_size=(3, 3))
        self.upstage_2 = OmninetUp(self.upstage_1.deepsig_actual_output_channels, 64, self.down_stage1.deepsig_actual_output_size, 32, kernel_size=(3, 3))
        self.upstage_3 = OmninetUp(self.upstage_2.deepsig_actual_output_channels, 32, self.head_layer.deepsig_actual_output_channels, 16, kernel_size=(3, 3))

        self.tail = OmninetOutput(self.upstage_3.deepsig_actual_output_channels, number_classes)

        self.nfft = nfft
        self.steps = steps
        # Non-persistent buffer: follows the module across devices with .to()
        # but is not written to the state_dict, so checkpoints are unchanged.
        self.register_buffer(
            "win",
            torch.hann_window(self.nfft).reshape((1, 1, self.nfft, 1)),
            persistent=False,
        )

    def forward(self, x):
        """Segment a batch of IQ windows.

        Parameters
        ----------
        x : torch.Tensor
            Real tensor reshapeable to ``(batch, steps, nfft, 2)``; the last
            axis holds the real and imaginary parts.

        Returns
        -------
        (tail, wf)
            ``tail``: class probabilities, ``(batch, number_classes, steps, nfft)``.
            ``wf``: the normalised spectrogram fed to the network,
            ``(batch, 1, steps, nfft)``.
        """
        # Local copy on the input's device; the module state is left untouched.
        win = self.win.to(x.device)
        samples = x.reshape((-1, self.steps, self.nfft, 2)) * win
        samples = torch.view_as_complex(samples)
        spectrum = torch.fft.fft(samples)
        # Swap the two halves so that DC ends up in the middle of the axis.
        half = self.nfft // 2
        spectrum = torch.cat((spectrum[:, :, half:], spectrum[:, :, :half]), -1).unsqueeze(1)

        magnitude = torch.abs(spectrum)
        psd = torch.log10(1e-9 + magnitude + magnitude)
        # Per-example standardisation (zero mean, unit variance).
        psd = psd - psd.mean((1,2,3), keepdim=True)
        wf = psd / (1e-9 + torch.std(psd, (1,2,3), True, keepdim=True))

        head = self.head_layer(wf)
        stage1_pool, stage1_out = self.down_stage1(head)
        stage2_pool, stage2_out = self.down_stage2(stage1_out)
        stage3_pool, stage3_out = self.down_stage3(stage2_out)
        stage4_pool, stage4_out = self.down_stage4(stage3_out)

        upstage1 = self.upstage_1(stage4_out, stage3_pool)
        upstage2 = self.upstage_2(upstage1, stage2_pool)
        upstage3 = self.upstage_3(upstage2, stage1_pool)

        tail = self.tail(upstage3)
        return tail, wf


In [None]:
# Small constant added inside log() by the loss functions below so that a
# predicted probability of exactly 0 does not produce -inf.
eps = 1e-6
# Prefer the first CUDA GPU; fall back to the CPU so the notebook still runs
# on a runtime without an accelerator.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

def wcce(pred, target):
    """Class-weighted categorical cross-entropy.

    Sums ``-class_W * target * log(eps + pred)`` over axis 1 and averages the
    result over all remaining axes.

    NOTE(review): relies on a module-level ``class_W`` that is not defined in
    the visible part of this notebook; it must exist (and broadcast against
    ``target``) before this function is called.
    """
    weighted_log_likelihood = class_W * target * torch.log(eps + pred)
    per_pixel_loss = -torch.sum(weighted_log_likelihood, 1)
    return torch.mean(per_pixel_loss)

def cce(pred, target):
    """Categorical cross-entropy between probabilities and one-hot targets.

    Sums ``-target * log(eps + pred)`` over axis 1 and averages the result
    over all remaining axes. ``eps`` is the module-level constant.
    """
    log_likelihood = target * torch.log(eps + pred)
    per_pixel_loss = -torch.sum(log_likelihood, 1)
    return torch.mean(per_pixel_loss)

print(nClass)
# Avoid blowing up memory in the event a net already exists with refcounts that won't disappear
try:
  del net
except NameError:
  # First run in this kernel: there is no previous network to drop.
  pass

# The network computes a single-channel spectrogram internally, hence 1 input channel.
net = Omninet(1, nClass)
net.to(device)
criterion = cce
optimizer = optim.Adam(net.parameters(), lr=3e-4)


In [None]:
# Additional un-augmented batch generator for ad-hoc inspection.
checker_generator = train_gen.keras_gen_val()
net.to(device)
plt.figure("loss")
# Per-epoch mean losses, appended to by the training loop.
train_loss, val_loss = [], []



In [None]:
# Batches drawn per epoch, for the training pass and for the validation pass.
number_steps = 10
net.to(device)
for epoch in range(200):
    # ---- training pass ----
    running_train_loss = 0.0
    net.train()
    for step in range(number_steps):
        optimizer.zero_grad()
        train_x_array, train_y_array = next(training_generator)
        train_x = torch.from_numpy(train_x_array).float().to(device)
        train_y_reshaped = torch.from_numpy(train_y_array).to(device)

        train_y_hat, _ = net(train_x)
        # Shape check: predictions must line up with the label layout.
        scoreable_y_hat = train_y_hat.view(-1, nClass, steps, nfft)
        loss = criterion(scoreable_y_hat, train_y_reshaped)
        loss.backward()
        optimizer.step()

        running_train_loss += loss.item()

    train_loss.append(running_train_loss / number_steps)
    # Drop references to the last training graph before validation starts.
    del loss
    del train_y_hat
    del scoreable_y_hat
    del train_y_reshaped

    # ---- validation pass ----
    running_val_loss = 0.0
    net.eval()
    # No gradients are needed here; without no_grad every validation batch
    # would build and keep an autograd graph on the device.
    with torch.no_grad():
        for step in range(number_steps):
            val_x_array, val_y_array = next(validation_generator)
            val_x = torch.from_numpy(val_x_array).float().to(device)
            val_y_reshaped = torch.from_numpy(val_y_array).to(device)

            val_y_hat, _ = net(val_x)
            scoreable_y_hat = val_y_hat.view(-1, nClass, steps, nfft)

            loss = criterion(scoreable_y_hat, val_y_reshaped)
            running_val_loss += loss.item()

    val_loss.append(running_val_loss / number_steps)
    # Report per-batch means so the numbers are comparable with the histories.
    print("epoch %d: train loss %.6f, val loss %.6f" % (epoch, train_loss[-1], val_loss[-1]))

In [None]:


# Draw one un-augmented batch here so the cell also works on a fresh kernel;
# no earlier cell defines ex_in / target.
ex_in, target = train_gen.generate(augment=False)
print(ex_in.shape)
print(target.shape)


In [None]:
# Shapes of the last prediction and the last training label batch.
# A bare expression is only displayed when it is the final line of a cell,
# so both are printed explicitly.
print(scoreable_y_hat.shape)
print(train_y_array.shape)

In [None]:
# Run the network on one fresh un-augmented batch and compare the first
# example's prediction with its target.
ex_in, target = train_gen.generate(augment=False)

net.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    inferred_y, wf = net(torch.from_numpy(ex_in).to(device))
inferred_y = inferred_y.cpu()
wf = wf.cpu()

fig, axes = plt.subplots(4, 1, figsize=(8, 12))

axes[0].set_title("Inferred bg confidence")
axes[0].imshow(inferred_y[0, 0].numpy())

# Same colour scale for prediction and target so class colours match.
class_style = dict(vmin=0, vmax=15, cmap="tab20", interpolation="none")

axes[1].set_title("Inferred class map")
axes[1].imshow(inferred_y[0].argmax(0).numpy(), **class_style)

axes[2].set_title("Target class map")
axes[2].imshow(target[0].argmax(0), **class_style)

axes[3].set_title("Waterfall")
axes[3].imshow(wf[0, 0].numpy())

plt.show()
