In [12]:
#Dataloaders for the 4DeNoise project

In [13]:
#Importing libraries
import os

import hyperspy.api as hs
import matplotlib.pyplot as plt
import torch
#%matplotlib widget

In [14]:
# Inference window: the pixels fed to the NN as input in order to predict the centre (output) pixel.
# Default: a 1-row x 5-column window, i.e. the 2 pixels to the left and the 2 to the right.
inference_H = 1   # window height in pixels (rows)
inference_W = 5   # window width in pixels (columns)
x_offset = 2      # column of the centre pixel within the window
y_offset = 0      # row of the centre pixel within the window

# Alternative: the 8 pixels surrounding the centre of a 3x3 window.
# inference_H = 3
# inference_W = 3
# x_offset = 1
# y_offset = 1

# Offsets of the input pixels relative to the output pixel, in [dy, dx] (numpy [y, x]) order,
# so that data[y + dy, x + dx] is an input pixel. The centre offset [0, 0] itself is excluded.
inputs_coords = [
    [y - y_offset, x - x_offset]
    for y in range(inference_H)
    for x in range(inference_W)
    if (y, x) != (y_offset, x_offset)
]

In [15]:
#Custom dataset object
class _IndexOutOfRange(IndexError, ValueError):
    """Raised for an out-of-range sample index.

    Subclasses IndexError so that plain iteration over the dataset stops cleanly,
    and ValueError so that code catching the previously raised ValueError still works.
    """


class DataSet(torch.utils.data.Dataset):
    """Pairs each pixel position of one or more 4D arrays with its neighbouring positions.

    A sample is addressed by a flat index over every position (in every loaded
    image) whose whole neighbourhood, given by ``inputs_coords`` as ``(dy, dx)``
    offsets, lies inside the first two axes of the data array. Each sample is
    ``(centre, neighbours)`` as float64 tensors: ``centre`` is ``data[y, x]`` and
    ``neighbours`` stacks ``data[y + dy, x + dx]`` along a new leading axis.
    """

    def __init__(self, file_paths, neighbour_coords=None):
        """Load the source files.

        Args:
            file_paths: list of paths to files readable by ``hs.load(..., reader="hspy")``.
                A single path string is also accepted.
            neighbour_coords: ``[dy, dx]`` offsets of the neighbours relative to the centre
                pixel. Defaults to the module-level ``inputs_coords`` as it is at
                construction time, so later edits to that global do not affect this dataset.
        """
        if isinstance(file_paths, str):  # a bare string would otherwise be iterated char by char
            file_paths = [file_paths]
        if neighbour_coords is None:
            neighbour_coords = inputs_coords
        self.inputs_coords = [(int(dy), int(dx)) for dy, dx in neighbour_coords]
        self.imgs = [hs.load(file_path, reader="hspy") for file_path in file_paths]

    #Height and width
    def img_H(self, img_index):
        """Size of axis 0 of image ``img_index``'s data, i.e. the y axis indexed by ``getitem``."""
        return self.imgs[img_index].data.shape[0]

    def img_W(self, img_index):
        """Size of axis 1 of image ``img_index``'s data, i.e. the x axis indexed by ``getitem``."""
        return self.imgs[img_index].data.shape[1]

    def _margins(self):
        """Return (top, bottom, left, right): how far the neighbourhood reaches past the centre."""
        dys = [0] + [dy for dy, _ in self.inputs_coords]
        dxs = [0] + [dx for _, dx in self.inputs_coords]
        return -min(dys), max(dys), -min(dxs), max(dxs)

    def _valid_shape(self, img_index):
        """Return (rows, cols) of centre positions in an image whose neighbours are all in bounds."""
        top, bottom, left, right = self._margins()
        return (max(0, self.img_H(img_index) - top - bottom),
                max(0, self.img_W(img_index) - left - right))

    def index_location(self, index):
        """Map a flat sample index to ``(img_index, y_pos, x_pos)``.

        Centre positions run row-major over the valid region of each image in turn.
        They are shifted by the neighbourhood margins, so every neighbour is in bounds.
        Negative indices count from the end, as for a Python sequence.

        Raises:
            IndexError (also a ValueError): if ``index`` is out of range.
        """
        total = len(self)
        if index < 0:
            index += total
            if index < 0:
                raise _IndexOutOfRange("Index too low")
        if index >= total:  # index == len(self) is already one past the last sample
            raise _IndexOutOfRange("Index too high")

        top, _, left, _ = self._margins()
        for img_index in range(len(self.imgs)):
            valid_H, valid_W = self._valid_shape(img_index)
            count = valid_H * valid_W
            if index < count:  # it's in this image
                row, col = divmod(index, valid_W)
                return img_index, row + top, col + left
            index -= count

    def __len__(self):
        """Total number of valid centre positions over all loaded images."""
        return sum(h * w for h, w in (self._valid_shape(i) for i in range(len(self.imgs))))

    #Function that returns input/output pair
    def getitem(self, index):
        """Return ``(centre, neighbours)`` float64 tensors for sample ``index``.

        NOTE(review): the settings cell names the neighbours the NN *input* and the centre
        the *output*, yet the centre is returned first. The order is kept for existing
        callers; confirm which element the training code treats as input.
        """
        img_index, y_pos, x_pos = self.index_location(index)
        data = self.imgs[img_index].data
        # astype first: older torch versions cannot convert e.g. uint16 detector data.
        centre = torch.tensor(data[y_pos, x_pos].astype("float64"), dtype=torch.float64)
        # One fancy-indexing gather yields a (n_neighbours, ...) array directly, instead of
        # calling torch.tensor() on a Python list of arrays (slow).
        rows = [y_pos + dy for dy, _ in self.inputs_coords]
        cols = [x_pos + dx for _, dx in self.inputs_coords]
        neighbours = torch.tensor(data[rows, cols].astype("float64"), dtype=torch.float64)
        return centre, neighbours

    def __getitem__(self, index):
        return self.getitem(index)
    

In [None]:
# Noisy 4D-STEM input data in HyperSpy's hspy/HDF5 format (DataSet loads with reader="hspy").
# Several 4D datasets may be listed. Rather than editing this machine-specific default,
# set the DENOISE_DATA_FILES environment variable to one or more paths separated by os.pathsep.
DEFAULT_DATA_FILES = [
    r"C:\Users\m03855jw\Downloads\4D-STEM_data_for_anthracene\4D-STEM_data_for_anthracene\Mg31872\20221020_211713_data_binned2.hdf5"
]
_env_data_files = os.environ.get("DENOISE_DATA_FILES")
data_files = _env_data_files.split(os.pathsep) if _env_data_files else DEFAULT_DATA_FILES
testing_dataset = DataSet(data_files)

