In [22]:
import os
import numpy as np

import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from src.datasets import PatchDataset, PatchGridDataset
from src.gridnet_patches import GridNet, GridNetHex
from src.densenet import DenseNet
from src.training import train_gnet_2stage, train_gnet_finetune

In [13]:
# Image pre-processing pipeline: each image is resized, center-cropped to a
# square of side `patch_size`, converted to a tensor, and normalized per channel.
patch_size = 256
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(patch_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
xform = transforms.Compose(preprocessing_steps)

# Chunk size for gradient checkpointing: image arrays are processed
# 32 patches at a time to reduce memory usage.
atonce_patch_limit = 32

# Training GridNet on Cartesian ST data

In [20]:
# Generate train/validation data sets.
# All four inputs live under a single dataset root, so relocating the data
# only requires changing `data_dir`.
data_dir = os.path.expanduser("~/Desktop/aba_stdataset_20200212")
imgs_train = os.path.join(data_dir, "imgs256_train")
lbls_train = os.path.join(data_dir, "lbls256_train")
imgs_val = os.path.join(data_dir, "imgs256_val")
lbls_val = os.path.join(data_dir, "lbls256_val")

h_st, w_st = 35, 33  # Height and width of ST array
n_class = 13  # 13 distinct foreground tissue classes


# Dataset of all (dissociated) foreground patches -- for pre-training of patch classifier.
patch_train = PatchDataset(imgs_train, lbls_train, xform)
patch_val = PatchDataset(imgs_val, lbls_val, xform)

# Sanity check: show the shape of a single pre-processed patch.
x, y = patch_train[0]
print("Patch shape:", x.shape)

# Dataset of all image arrays -- for training of GridNet.
grid_train = PatchGridDataset(imgs_train, lbls_train, xform)
grid_val = PatchGridDataset(imgs_val, lbls_val, xform)

# Sanity check: show the shape of a single image array.
# NOTE(review): GridNet is constructed below with grid_shape=(h_st, w_st);
# presumably the leading dimensions printed here should match -- confirm.
x, y = grid_train[0]
print("Grid shape:", x.shape)


# Data Loaders - present batches of input/output pairs to training routines.
# Only the training splits are shuffled; the validation loaders keep a fixed
# order so that validation batches are the same from one epoch to the next.
batch_size = 1  # image arrays per batch (grid loaders)
patch_batch_size = 32  # patches per batch (patch loaders)
patch_loaders = {
    "train": DataLoader(patch_train, batch_size=patch_batch_size, shuffle=True, pin_memory=True),
    "val": DataLoader(patch_val, batch_size=patch_batch_size, shuffle=False, pin_memory=True),
}
grid_loaders = {
    "train": DataLoader(grid_train, batch_size=batch_size, shuffle=True, pin_memory=True),
    "val": DataLoader(grid_val, batch_size=batch_size, shuffle=False, pin_memory=True),
}

Patch shape: torch.Size([3, 256, 256])
Grid shape: torch.Size([32, 49, 3, 256, 256])


In [None]:
# Model formulation employed in our publication.

# Local (patch) classifier.
f = DenseNet(
    num_classes=n_class,
    small_inputs=False,
    efficient=False,
    growth_rate=32,
    block_config=(6, 12, 24, 16),
    num_init_features=64,
    bn_size=4,
    drop_rate=0,
)

# Global (grid) corrector, constructed from the patch classifier `f`.
g = GridNet(
    f,
    patch_shape=(3, patch_size, patch_size),
    grid_shape=(h_st, w_st),
    n_classes=n_class,
    use_bn=False,
    atonce_patch_limit=atonce_patch_limit,
)


# Randomly sample the hyperparameters for this fit:
#   lr    -- 10 raised to an exponent drawn uniformly from [-4, -3)
#   alpha -- drawn uniformly from [0, 0.1)
lr_exponent = np.random.uniform(-4, -3)
lr = 10 ** lr_exponent
alpha = 0.1 * np.random.random()

# Report the sampled values so the run can be identified afterwards.
print("Learning Rate: %.4g" % lr)
print("Alpha: %.4g" % alpha)

train_gnet_2stage(
    g,
    [patch_loaders, grid_loaders],
    lr,
    alpha=alpha,
    num_epochs=100,
)

# Training GridNetHex on Visium ST data