In [176]:
import os
import glob
from random import shuffle
import xml.etree.ElementTree as ET

from numpy import sum
from mrcfile import open as mrc_open
from PIL import Image
from skimage.exposure import equalize_adapthist as clahe
from skimage.transform import rescale
import torch

def norm_minmax(x):
    """Rescale ``x`` so that its minimum maps to 0 and its maximum to 1.

    The input is left untouched; a new array is returned.

    Parameters
    ----------
    x : array-like object exposing ``.min()`` / ``.max()`` and arithmetic
        (e.g. a numpy array).

    Returns
    -------
    The shifted and scaled values. When every element of ``x`` is equal the
    value range is zero, and an all-zero result is returned instead of
    dividing by zero.
    """
    # `x - x.min()` allocates a new array, so the caller's data is not modified
    shifted = x - x.min()
    peak = shifted.max()
    # constant input: avoid 0/0 (NaNs) and hand back zeros instead
    return shifted / (peak if peak != 0 else 1)

def mrc_to_img(path_to_mrc, scale=0.1):
    """Load an MRC file and turn it into a single rescaled 2-D image.

    The data read from ``path_to_mrc`` is summed along axis 0, min-max
    normalised with ``norm_minmax``, contrast-enhanced with
    ``equalize_adapthist`` (imported here as ``clahe``) and finally passed
    through ``skimage.transform.rescale`` with the factor ``scale``.
    """
    with mrc_open(path_to_mrc) as handle:
        stack = handle.data
    # collapse the leading axis into one image
    flattened = sum(stack, axis=0)
    # normalise first, then apply adaptive histogram equalisation
    enhanced = clahe(norm_minmax(flattened))
    return rescale(enhanced, scale)

class XMLMetadata:
    """Per-file metadata read from every ``*.xml`` file in a directory.

    Entries are keyed by the XML file's name without directory or extension
    (the same stem is shared by the matching ``.mrc`` file). Each entry is a
    dict that currently holds a single key, ``"shape"``, with the value
    ``[nb_fractions, w, h]``.
    """

    def __init__(self, source_dir):
        """Parse all ``*.xml`` files found directly inside ``source_dir``."""
        self.source_dir = source_dir
        self._data = dict()
        for path_to_xml in glob.glob(os.path.join(source_dir, "*.xml")):
            # stem shared by a .mrc; os.path handles whatever separator the
            # platform uses, unlike searching for a literal '/'
            name = os.path.splitext(os.path.basename(path_to_xml))[0]
            root = ET.parse(path_to_xml).getroot()
            # Values are looked up by position in the tree.
            # NOTE(review): this presumes a fixed element order in the XML
            # files -- confirm against the producer of these files.
            w, h = [int(x.text) for x in root[0][8][2:4]]
            nb_fractions = int(root[0][6].text)
            self._data[name] = {"shape": [nb_fractions, w, h]}
            # todo: ... do i need any of this metadata?

    def list_files(self):
        """Return the list of file stems for which metadata was loaded."""
        return list(self._data.keys())

    def __getitem__(self, key):
        """Return the metadata dict for the file stem ``key`` (KeyError if unknown)."""
        return self._data[key]
    
class MRCSampler:
    """Serve square tiles cut from the image of one MRC file.

    The image is produced by ``mrc_to_img(path_to_mrc, scale)``. Tiles are
    ``tile_size`` x ``tile_size`` windows placed every ``tile_stride`` pixels
    along both axes. Only windows that fit entirely inside the image are
    counted, so every tile has the full size; an image smaller than one tile
    yields a sampler of length 0.

    Attributes
    ----------
    cols : number of tile positions along axis 0 of the image.
    rows : number of tile positions along axis 1 of the image.
    I    : shuffled list of all tile indices, used by ``get_tile``.
    """

    def __init__(self, path_to_mrc, tile_size, tile_stride, scale=1):
        assert scale <= 1, "really don't need to do that"
        # fail before the (expensive) file load if the grid cannot be built
        if tile_size <= 0 or tile_stride <= 0:
            raise ValueError("tile_size and tile_stride must be positive")
        self.source_file = path_to_mrc
        self.tile_size = tile_size
        self.tile_stride = tile_stride
        self.img = mrc_to_img(path_to_mrc, scale)
        self._build_index()

    @staticmethod
    def _tile_count(extent, tile_size, tile_stride):
        """Number of full tiles that fit along an axis of length ``extent``."""
        if extent < tile_size:
            return 0
        return (extent - tile_size) // tile_stride + 1

    def _build_index(self):
        """Compute the tile grid for ``self.img`` and a shuffled visiting order."""
        w, h = self.img.shape
        self.cols = self._tile_count(w, self.tile_size, self.tile_stride)
        self.rows = self._tile_count(h, self.tile_size, self.tile_stride)
        self.I = list(range(len(self)))
        shuffle(self.I)

    def __getitem__(self, idx):
        """Return tile number ``idx`` (0 <= idx < len(self)) in grid order.

        Raises IndexError for any other index, which also lets a plain
        ``for tile in sampler`` loop terminate.
        """
        if not 0 <= idx < len(self):
            raise IndexError("tile index out of range")
        # idx // rows walks axis 0, idx % rows walks axis 1
        step0, step1 = divmod(idx, self.rows)
        start0 = step0 * self.tile_stride
        start1 = step1 * self.tile_stride
        return self.img[start0:start0 + self.tile_size,
                        start1:start1 + self.tile_size]

    def get_tile(self, i):
        """Return the ``i``-th tile of the shuffled order stored in ``self.I``."""
        return self.__getitem__(self.I[i])

    def __len__(self):
        """Total number of tiles in the grid."""
        return self.cols * self.rows

    @classmethod
    def get_factory(cls, tile_dim, stride, scale):
        """Return a one-argument callable ``path -> cls(path, tile_dim, stride, scale=scale)``."""
        return lambda path_to_mrc: cls(path_to_mrc, tile_dim, stride, scale=scale)

class MRCData(torch.utils.data.IterableDataset):
    """Iterable dataset that streams tiles from every ``*.mrc`` file in a directory.

    On each pass the file list is shuffled; files are then opened one at a
    time through ``sampler_factory`` and each resulting sampler is drained via
    ``sampler.get_tile(i)`` for ``i`` in ``range(len(sampler))`` before moving
    on. Every file is visited exactly once per pass; files whose sampler is
    empty are skipped.

    Parameters
    ----------
    source_dir : directory searched for ``*.mrc`` (and ``*.xml`` metadata) files.
    sampler_factory : callable taking a path and returning a sampler object
        that provides ``__len__``, ``get_tile`` and ``source_file``.
    transform : optional callable applied to a PIL image of the tile.
    K : number of transformed copies generated per tile when ``transform``
        is given.
    verbose : print progress information while iterating (off by default).

    Each item is either the raw tile (``transform is None``) or a list of
    ``K`` transformed copies of it.
    """

    def __init__(self, source_dir, sampler_factory, transform=None, K=1, verbose=False):
        super(MRCData, self).__init__()
        self.transform = transform
        self.K = K
        self.verbose = verbose
        # method for generating MRCSampler objects
        self._sampler_factory = sampler_factory
        # glob all the mrc files in the source directory
        self.source_dir = source_dir
        self.mrc_files = glob.glob(os.path.join(source_dir, "*.mrc"))
        # load metadata
        self.metadata = XMLMetadata(source_dir)
        self._reset()

    def _reset(self):
        """Put the iteration state back to 'nothing loaded yet'."""
        self._sampler = None
        self._idx = 0   # position inside the current sampler
        self._loc = 0   # index of the next file to open
        # True once there is no further file left to open
        self.depleted = not self.mrc_files

    def _refresh_sampler(self):
        """Open the next file in ``self.mrc_files`` and start at its first tile."""
        if self.verbose:
            print("loc", self._loc)
        self._sampler = self._sampler_factory(self.mrc_files[self._loc])
        self._loc += 1
        self._idx = 0
        self.depleted = self._loc >= len(self.mrc_files)

    def __next__(self):
        # move on while nothing is loaded or the current sampler is used up;
        # the loop (rather than a single `if`) also skips empty samplers
        while self._sampler is None or self._idx >= len(self._sampler):
            if self.depleted:
                raise StopIteration
            self._refresh_sampler()
        # get a randomly-sampled tile from the image
        img = self._sampler.get_tile(self._idx)
        self._idx += 1
        if self.verbose:
            print(self._idx, self._sampler.source_file)
        if self.transform is None:
            return img
        # generate K augmentations of the image
        # see https://github.com/mpatacchiola/self-supervised-relational-reasoning/
        pic = Image.fromarray(img)
        return [self.transform(pic.copy()) for _ in range(self.K)]

    def __iter__(self):
        # check before touching any file so an unsupported setup fails fast
        if torch.utils.data.get_worker_info() is not None:
            raise NotImplementedError("this class doesn't support multiprocess loading (yet)")
        shuffle(self.mrc_files)
        self._reset()
        return self

    # todo: calculates length from the nb of files, their dimensions, and the sampling parameters
    # def __len__():
    #     assert self.uniform, "cannot calculate __len__ on a non-uniform dataset"

In [177]:
# Single-file alternative, kept for reference:
# sampler = MRCSampler("../data/brians_data/0417.mrc", 50, 20, scale = 0.1)
# Build a sampler factory (tile 50, stride 20, scale 0.1) and the dataset over the data directory.
factory = MRCSampler.get_factory(tile_dim=50, stride=20, scale=0.1)
ds = MRCData(source_dir="../data/brians_data", sampler_factory=factory)