In [1]:
import numpy as np

from astrosite_dataset import AstrositeDataset

target_id = 37867
dataset_path = '../dataset/recordings'
# The dataset split is selected by the target id, passed as a string.
dataset = AstrositeDataset(dataset_path, split=f"{target_id}")

In [2]:
# Peek at the first five raw and labelled events of the first recording.
# Index the dataset once: dataset[0] may read the recording from disk on every access (unverified).
first_sample = dataset[0]
print(first_sample['events'][:5])
print(first_sample['labelled_events'][:5])

[(  0,  175, 174,  True) (456,  895, 275, False) (568, 1145, 413, False)
 (596, 1093, 246, False) (644,  462, 284,  True)]
[(  0, 175, 174,  True,  0) (728, 177, 173,  True, 26)
 (775, 177, 171,  True, 25) (793, 184, 172,  True,  0)
 (950, 427, 467,  True,  0)]


In [3]:
from spinnaker_loader import EventsLoader

# NOTE(review): train_loader is not used by any later cell shown in this notebook.
train_loader = EventsLoader(
    dataset,
    sample_time=1_000_000,  # presumably microseconds, i.e. 1 s per sample — confirm in EventsLoader
    bins_per_sample=8,
)

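Collect 9 samples that have no label below -1 (samples containing label -2 are skipped), and from each keep only the events labelled -1, i.e. the satellite. Samples with 320 or fewer satellite events are also skipped.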

In [17]:
samples = []
n_samples = 9
# Label whose events are kept; samples containing any label below it are skipped.
# NOTE(review): -2 presumably marks a different object class — confirm in AstrositeDataset.
target_label = -1
# A sample is kept only if it has more than this many target-labelled events.
min_sat_events = 320

for sample in dataset:
    events = sample['labelled_events']
    if len(events) == 0:
        # No labelled events: nothing to keep, and .min() would fail on an empty array.
        continue
    if events['label'].min() < target_label:
        # .tolist() yields plain ints, so the message prints the same on NumPy 1.x and 2.x.
        labels = np.unique(events['label']).tolist()
        print(f"Skipping sample with labels {labels}")
        continue

    # Boolean-mask indexing returns a copy, so later edits to `samples` do not touch `dataset`.
    sat_events = events[events['label'] == target_label]
    if len(sat_events) > min_sat_events:
        samples.append(sat_events)
    if len(samples) >= n_samples:
        break

Skipping sample with labels [-2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38]
Skipping sample with labels [-2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]


In [19]:
# The tiling cell below indexes into `samples` by grid position; fail early with a clear message if too few were found.
assert len(samples) == n_samples, f"collected {len(samples)} samples, expected {n_samples}"
len(samples)


9

In [20]:
# Tile the n_samples selected recordings into a side_length x side_length mosaic.
# Sample i*side_length + j is shifted right by j sensor widths and down by i sensor heights.
# All tiles are then merged into one stream ordered by timestamp.
side_length = int(round(np.sqrt(n_samples)))
# A non-square n_samples would silently drop samples; too few samples would raise IndexError.
assert side_length * side_length == n_samples, f"n_samples={n_samples} is not a perfect square"
assert len(samples) >= n_samples, f"need {n_samples} samples, only {len(samples)} collected"

# Presumably the sensor resolution in pixels (consistent with the coordinates printed above) — confirm.
max_x = 1280
max_y = 720

tiles = []
for index in range(n_samples):
    row, col = divmod(index, side_length)
    # Shift a copy so `samples` stays untouched and re-running this cell is idempotent.
    tile = samples[index].copy()
    tile['x'] += max_x * col
    tile['y'] += max_y * row
    tiles.append(tile)

merged_events = np.sort(np.concatenate(tiles), order='t')

In [21]:
import event_stream

# The mosaic is side_length tiles wide and tall, each tile max_x x max_y pixels
# (the same values used to shift the tiles in the cell above).
encoder = event_stream.Encoder('merged_events.es', 'dvs', max_x * side_length, max_y * side_length)

In [22]:
# Write the merged mosaic stream to merged_events.es in a single call.
# NOTE(review): merged_events still carries the 'label' field from labelled_events; confirm the 'dvs'
# encoder ignores extra fields rather than misreading them.
# NOTE(review): the encoder is never explicitly closed in this notebook; confirm the file is fully
# flushed before reading it back (e.g. `del encoder`, or a context manager if event_stream supports one).
encoder.write(merged_events)

In [23]:
# Same selection rule as the cell that builds `samples`, but over the whole split with no n_samples cap.
target_sat_events = []
# Label whose events are kept; samples containing any label below it are skipped.
# NOTE(review): -2 presumably marks a different object class — confirm in AstrositeDataset.
target_label = -1
# A sample is kept only if it has more than this many target-labelled events.
min_sat_events = 320

for sample in dataset:
    events = sample['labelled_events']
    if len(events) == 0:
        # No labelled events: nothing to keep, and .min() would fail on an empty array.
        continue
    if events['label'].min() < target_label:
        # .tolist() yields plain ints, so the message prints the same on NumPy 1.x and 2.x.
        labels = np.unique(events['label']).tolist()
        print(f"Skipping sample with labels {labels}")
        continue

    sat_events = events[events['label'] == target_label]
    if len(sat_events) > min_sat_events:
        target_sat_events.append(sat_events)

Skipping sample with labels [-2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38]
Skipping sample with labels [-2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]


In [24]:
# Number of qualifying samples in the whole split (the collection loop above has no n_samples cap).
len(target_sat_events)

9