Skip to content

Commit

Permalink
Add correct graph labelling
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Jul 5, 2020
1 parent 275c65d commit 639741c
Showing 1 changed file with 30 additions and 8 deletions.
38 changes: 30 additions & 8 deletions factnn/generator/pytorch/eventfile_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,15 @@

class EventFileDataset(Dataset):

def __init__(self, root, transform=None, pre_transform=None):
def __init__(self, root, transform=None, pre_transform=None, task="Separation", num_points=0):
"""
:param task: Either 'Separation', 'Energy', or 'Disp'
:param num_points: The number of points to have, either using points multiple times, or subselecting from the total points
"""
super(EventFileDataset, self).__init__(root, transform, pre_transform)
self.processed_filenames = []
self.task = task
self.num_points = num_points

@property
def raw_file_names(self):
Expand Down Expand Up @@ -45,19 +51,30 @@ def process(self):
cx=GEOMETRY.x_angle,
cy=GEOMETRY.y_angle))
# Read data from `raw_path`.
if self.num_points > 0:
if point_cloud.shape[0] < self.num_points:
point_indicies = np.random.choice(point_cloud.shape[0], self.num_points, replace=True)
else:
point_indicies = np.random.choice(point_cloud.shape[0], self.num_points, replace=False)
point_cloud = point_cloud[point_indicies]
data = Data(pos=point_cloud) # Just need x,y,z ignore derived features, padding would in dataloader
if "gamma" in raw_path:
data.event_type = torch.tensor(0, dtype=torch.int8)
elif "proton" in raw_path:
data.event_type = torch.tensor(1, dtype=torch.int8)
else:
print("No Event Type")
continue
data.energy = torch.tensor(event_data[data_format["Energy"]], dtype=torch.float)
data.disp = torch.tensor(euclidean_distance(event_data[data_format['Source_X']], event_data[data_format['Source_Y']],
event_data[data_format['COG_X']], event_data[data_format['COG_Y']]),
data.disp = torch.tensor(true_sign(event_data[data_format['Source_X']],
event_data[data_format['Source_Y']],
event_data[data_format['COG_X']],
event_data[data_format['COG_Y']],
event_data[data_format['Delta']])* euclidean_distance(event_data[data_format['Source_X']],
event_data[data_format['Source_Y']],
event_data[data_format['COG_X']],
event_data[data_format['COG_Y']]),
dtype=torch.float16)
data.sign = torch.tensor(true_sign(event_data[data_format['Source_X']], event_data[data_format['Source_Y']],
event_data[data_format['COG_X']], event_data[data_format['COG_Y']],
event_data[data_format['Delta']]), dtype=torch.uint8)

if self.pre_filter is not None and not self.pre_filter(data):
continue

Expand All @@ -73,5 +90,10 @@ def len(self):

def get(self, idx):
data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))
if self.task == "Energy":
data.y = data.energy
elif self.task == "Disp":
data.y = data.disp
else:
data.y = data.event_type
return data

0 comments on commit 639741c

Please sign in to comment.