In [None]:
import os
from pathlib import Path

import numpy as np
import tensorflow as tf
import torch
from imageio import imread
from torch.utils.data import Dataset, DataLoader

In [None]:
# Maximum number of tiles kept per patient; both loaders truncate longer tile lists to N.
N = 200

In [None]:
# this is the tensorflow version
def _as_text(value):
    """Return `value` as a plain Python object, unwrapping eager tensors and bytes."""
    if hasattr(value, "numpy"):
        # tf.py_function hands its inputs over as eager tensors.
        value = value.numpy()
    if isinstance(value, bytes):
        value = value.decode("utf-8")
    return value


def prepare(patient, root_dir, label_info):
    """Load at most N tile images for one patient (TensorFlow pipeline).

    Args:
        patient: Name of the patient's sub-directory, given as a string tensor
            (what `tf.py_function` passes in), `bytes`, or `str`.
        root_dir: Directory that contains one sub-directory of tiles per patient.
        label_info: Mapping from patient name (str) to that patient's label.

    Returns:
        A `(tiles, label)` tuple. `tiles` is a float32 array stacking the images
        read from the patient's directory (at most N of them); `label` is
        `label_info[<patient name>]`.

    Raises:
        KeyError: if the patient name is not a key of `label_info`.
    """
    patient_name = _as_text(patient)
    patient_dir = Path(_as_text(root_dir)) / patient_name
    # A Path object is not iterable, so list the directory explicitly.
    # Sorting makes the truncation below independent of the listing order.
    images = sorted(x for x in patient_dir.iterdir() if x.is_file())
    if len(images) > N:
        print("Truncating input")
        # Truncate the file list first so dropped tiles are never read.
        images = images[:N]
    # convert to gray?
    output = [imread(image) for image in images]
    # Look the label up by the decoded name, not by the tensor itself.
    label = label_info[patient_name]
    # float32 matches the output dtype that `warp` declares for py_function.
    return np.asarray(output, dtype=np.float32), label

def warp(filename, root_dir, label_info):
    """Wrap `prepare` in a `tf.py_function` so it can run inside `Dataset.map`.

    Args:
        filename: String tensor holding one patient directory name.
        root_dir: Root directory of the tiles (plain Python value).
        label_info: Mapping from patient name to label (plain Python value).

    Returns:
        The two float32 tensors produced by `tf.py_function`: tiles and label.
    """
    # Only `filename` is a tensor. `root_dir` and `label_info` are ordinary
    # Python objects, so they are captured by the closure rather than listed in
    # `inp`, whose entries must be tensors.
    return tf.py_function(
        lambda patient: prepare(patient, root_dir, label_info),
        [filename],
        [tf.float32, tf.float32],
    )
    
def _load_dataset(filenames, root_dir, batch_size, label_dict):
    """Build the batched `tf.data` pipeline over a list of patient names.

    Args:
        filenames: Patient directory names to load.
        root_dir: Root directory of the tiles, forwarded to `warp`.
        batch_size: Number of patients per batch; incomplete final batches
            are dropped.
        label_dict: Mapping from patient name to label, forwarded to `warp`.

    Returns:
        A `tf.data.Dataset` yielding batched `(tiles, label)` pairs.
    """
    # The original referenced a bare `AUTOTUNE`, which is not defined anywhere.
    autotune = tf.data.experimental.AUTOTUNE
    files = tf.data.Dataset.from_tensor_slices(filenames)
    dataset = files.map(
        lambda filename: warp(filename, root_dir, label_dict),
        num_parallel_calls=autotune,
    )
    batchedDS = dataset.batch(batch_size, drop_remainder=True)
    # Prefetch last so whole batches are prepared ahead of the consumer.
    return batchedDS.prefetch(buffer_size=autotune)


In [None]:
# This is the pytorch version
class tileDataset(Dataset):
    """Map-style dataset that yields, per patient, a stack of tiles and a label.

    Each item is built from the files found in `<root>/<patient>`; at most N
    tiles are used per patient.

    Args:
        tile_root_directory: Directory with one sub-directory per patient
            (`str` or `Path`).
        patient_list: Sequence of patient directory names; defines the length
            and ordering of the dataset.
        label_dict: Mapping from patient name to its label.
        data_transform: Callable applied to every image returned by `imread`.
            It must return a tensor, and all tiles of a patient must share one
            shape so they can be stacked.
    """

    def __init__(self, tile_root_directory, patient_list, label_dict, data_transform):
        super().__init__()
        self.label_info = label_dict
        # Coerce to Path so a plain string root also works with the `/`
        # operator used in __getitem__.
        self.root_dir = Path(tile_root_directory)
        self.patients = patient_list
        self.transform = data_transform

    def __getitem__(self, idx):
        """Return `(tiles, label)` for the patient at position `idx`.

        Returns:
            tiles: float tensor with the transformed tiles stacked along dim 0.
            label: long tensor built from the patient's entry in the label mapping.

        Raises:
            FileNotFoundError: if the patient's directory contains no files.
            KeyError: if the patient has no entry in the label mapping.
        """
        curr_patient = self.patients[idx]
        # get tiles for this patient
        lookup_dir = self.root_dir / curr_patient
        # Sorted so that truncation always keeps the same tiles.
        tile_names = sorted(x for x in lookup_dir.iterdir() if x.is_file())
        if not tile_names:
            # Fail with a clear message instead of an opaque stacking error.
            raise FileNotFoundError(f"No tiles found for patient {curr_patient} in {lookup_dir}")
        if len(tile_names) > N:
            print(f"Input for patient {curr_patient} exceeds limit, truncating..")
            tile_names = tile_names[:N]
        tiles = torch.stack([self.transform(imread(t)) for t in tile_names], dim=0)
        label = self.label_info[curr_patient]
        return tiles.float(), torch.tensor(label).long()

    def __len__(self):
        """Number of patients in the dataset."""
        return len(self.patients)
            
        


In [None]:
def update_performance_info(curr_loss, performance):
    """Placeholder hook for recording a batch loss; currently does nothing.

    Args:
        curr_loss: Loss value of the current batch (unused for now).
        performance: Container meant to accumulate metrics (left untouched).
    """
    return None


def train_epoch(dataloader, model, criterion, optimizer, performance_info):
    """Run one optimisation pass over `dataloader`.

    Args:
        dataloader: Iterable yielding `(X, y)` batches.
        model: Callable mapping a batch `X` to predictions.
        criterion: Callable mapping `(predictions, y)` to a loss that supports
            `.backward()`.
        optimizer: Object exposing `zero_grad()` and `step()`.
        performance_info: Passed through to `update_performance_info` after
            every batch.
    """
    for i, (X, y) in enumerate(dataloader):
        # Clear gradients before the backward pass so anything left over from
        # earlier work (e.g. an interrupted cell run) cannot leak into this step.
        optimizer.zero_grad()
        output = model(X)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        # Hand the tracker a detached loss so storing it does not keep the
        # autograd graph of this batch alive.
        update_performance_info(loss.detach(), performance_info)
        if i % 100 == 0:
            print(f"Training batch loss {loss}")
            
    

In [None]:
# TODO: fill in `patient_list`, `label_dict` and `model`; this cell cannot run
# to completion while they are still None.
patient_list = None
label_dict = None
batch = 2
# NOTE(review): `transforms` is presumably torchvision.transforms, which is not
# imported in the visible cells. Add `from torchvision import transforms` to
# the import cell.
# ToPILImage comes first because imread yields a NumPy array, while Resize only
# accepts PIL images or tensors.
data_transform = transforms.Compose([transforms.ToPILImage(),
                                     transforms.Resize(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
tiledata = tileDataset('../data/tiles/', patient_list, label_dict, data_transform)
dloader = DataLoader(tiledata, batch_size=batch, shuffle=True)
model = None # TODO
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# NOTE(review): tileDataset returns integer (long) labels, whereas MSELoss
# expects float targets. Confirm which loss / label dtype is intended.
criterion = torch.nn.MSELoss()

In [None]:
# Start training
    