In [1]:
import numpy as np
import torch
import os
import matplotlib.pyplot as plt
from h5py import Dataset
from jupyter_server.utils import fetch

In [2]:
# Base directory of the external drives, presumably where the dataset lives.
# NOTE(review): hard-coded, user/machine-specific path and not referenced
# elsewhere in this view — consider reading it from an env var or config.
mount_point = "/home/felipe/ExternalDrives"

In [7]:
from torch.utils.data import Dataset
from typing import List, Tuple, Iterable

class MILDataset(Dataset):
    """
    Base dataset for multiple-instance learning (MIL): every item is a bag of tiles.

    Subclass MILDataset and implement the fetch_tiles method for your dataset.

    fetch_tiles
      input: str -> unique identifier for each bag
      output: tuple(tiles, label) -> a tensor of the bag's tiles stacked as
              (N, C, W, H) and the single label shared by the whole bag

    When creating DataLoaders, use the MILDataset.collate method as your collate function.
    This is necessary because bags have a variable number of tiles, so the
    default collate (which stacks samples) cannot batch them.
    """

    def __init__(
            self,
            bag_ids: List[str],
    ):
        super().__init__()
        # Positional index -> bag identifier, so the integer index a sampler
        # passes to __getitem__ can be resolved to a bag.
        self.bag_ids = dict(enumerate(bag_ids))

    def __len__(self):
        return len(self.bag_ids)

    def __getitem__(self, i):
        """Return ``(bag index as a 0-d tensor, tiles, label)`` for the i-th bag."""
        bag = self.bag_ids[i]
        tiles, label = self.fetch_tiles(bag)
        return torch.tensor(i), tiles, label

    @staticmethod
    def fetch_tiles(bag: str, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Load the tiles and label of one bag; must be overridden by subclasses.

        bag: identifier of the bag (an entry of the ``bag_ids`` list).
        Returns (tiles, label) where tiles is an (N, C, W, H) tensor.
        Raises NotImplementedError in this base class.
        """
        raise NotImplementedError

    @staticmethod
    def collate(
            batch: Iterable[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Merge (bag, tiles, label) samples whose bags hold different tile counts.

        Per-sample shapes
          bag   -> scalar (0-d, or a single-element tensor) bag index
          tiles -> (N_i, C, W, H); N_i may differ from bag to bag
          label -> the bag's label (tensor or Python number)

        Returns
          bags   -> (sum N_i,) bag index of every tile, so tiles can be
                    regrouped per bag downstream (pooling / attention)
          tiles  -> (sum N_i, C, W, H) all tiles concatenated along the tile axis
          labels -> (B, *label.shape) one label per bag, stacked in batch order
        """
        batch_bags, batch_tiles, batch_labels = zip(*batch)

        # Bag ids come from torch.tensor(i) and are 0-d; torch.cat rejects 0-d
        # tensors, so normalise each to a scalar and stack into shape (B,).
        bag_index = torch.stack([torch.as_tensor(b).reshape(()) for b in batch_bags])

        # Tiles vary in count per bag, so they are joined along the tile axis
        # (dim 0), not the channel axis.
        collated_tiles = torch.cat(batch_tiles, dim=0)

        # Repeat each bag index once per tile to keep tile -> bag membership,
        # which would otherwise be lost after concatenation.
        tiles_per_bag = torch.tensor([t.shape[0] for t in batch_tiles])
        collated_bags = torch.repeat_interleave(bag_index, tiles_per_bag)

        collated_labels = torch.stack([torch.as_tensor(l) for l in batch_labels])

        return collated_bags, collated_tiles, collated_labels