This script takes a set of 4k x 4k TMA spot images as input and produces patches
for training a multiple-instance learning (MIL) model. Each instance (patch) inherits the label of its bag (TMA spot).

Assumptions for this script to work:
* the following file structure:

    ROOT/
    ├── tma_spots/
    │   ├── slide_1.png
    │   ├── slide_2.png
    │   └── ...
    └── data/
        └── {dataname}_bag_labels.csv

* `{dataname}_bag_labels.csv` has a header row with the columns `slide_id` and `label`.

			
Output (written to the `data` directory):

{dataname}_train.pytable

{dataname}_val.pytable

TODO:
* Save slide-level id and label for each of the input images (TMA spots)

In [None]:
# --- parameters ---------------------------------------------------------
# naming / input
dataname = 'tma_patches'             # prefix of the bag-label CSV and the output pytables
tma_spot_directory = './tma_spots'   # directory holding the input TMA spot images

# tiling
patch_size = 256                     # edge length (pixels) of each square patch
stride_size = 256                    # step (pixels) between neighbouring patches
mirror_pad_size = 128                # reflect-padding (pixels) added to every image border
resize = 1                           # scale factor applied to each image before tiling

# data split / labels
test_set_size = 0.1                  # fraction of spots held out for validation
classes = [0, 1]                     # possible bag labels

IMPORTS

In [None]:
import torch
import tables

import os,sys
import glob

import PIL
import numpy as np

import cv2
import matplotlib.pyplot as plt
import csv

from sklearn import model_selection
import sklearn.feature_extraction.image
import random

# Draw a fresh seed each run, print it, and seed both RNGs the pipeline relies on:
# Python's `random` and NumPy's global generator (sklearn's ShuffleSplit falls back
# to the NumPy global RNG when no random_state is given). To reproduce a run, set
# `seed` to the printed value instead of drawing it.
# The seed is kept below 2**32 because np.random.seed / RandomState reject larger values.
seed = random.randrange(2**32)
random.seed(seed)
np.random.seed(seed)
print(f"random seed (note down for reproducibility): {seed}")

In [None]:
img_dtype = tables.UInt8Atom()  # images (and labels) are saved as unsigned 8-bit ints, i.e. values in [0, 255]
filenameAtom = tables.StringAtom(itemsize=255)  # fixed-width (255 byte) strings for file names

# Get the raw TMA spot images. The list is sorted because glob order is
# filesystem-dependent and the split below selects files by index.
files = sorted(glob.glob(os.path.join(tma_spot_directory, '*.png')))
if not files:
    raise FileNotFoundError(f"no .png TMA spots found in {tma_spot_directory!r}")

# Read the bag-level labels as a list of row dicts.
# Assumes the CSV has a header row with the columns `slide_id` and `label`.
with open(f'./data/{dataname}_bag_labels.csv', 'r', newline='') as tmp:
    bag_labels = list(csv.DictReader(tmp))

# Create the training and validation phases and split the files between them.
# random_state ties the split to the printed seed; the modulo keeps it inside
# the [0, 2**32) range that sklearn/NumPy accept.
phases = {}
splitter = model_selection.ShuffleSplit(n_splits=1, test_size=test_set_size, random_state=seed % 2**32)
phases["train"], phases["val"] = next(splitter.split(files))

In [None]:
storage = {}  # holder for the pytables arrays of the phase currently being written

# Shape of one stored image item: a patch_size x patch_size RGB patch.
# Labels and slide ids are stored as one scalar per patch.
block_shape = {}
block_shape["img"] = np.array((patch_size, patch_size, 3))

filters = tables.Filters(complevel=6, complib='zlib')  # zlib compression to keep the pytables small on disk


def _tile_image(img, size, stride):
    """Cut an H x W x C image into size x size tiles taken every `stride` pixels.

    Tiles are returned in row-major order as an (n_tiles, size, size, C) array,
    matching the ordering of the removed sklearn ``extract_patches`` + reshape.
    Returns an empty (0, size, size, C) array if the image is smaller than one tile.
    """
    rows = range(0, img.shape[0] - size + 1, stride)
    cols = range(0, img.shape[1] - size + 1, stride)
    tiles = [img[r:r + size, c:c + size] for r in rows for c in cols]
    if not tiles:
        return np.empty((0, size, size, img.shape[2]), dtype=img.dtype)
    return np.stack(tiles)


def _lookup_bag_row(fname, lookup):
    """Return the bag-label CSV row (from `lookup`, keyed by slide_id) for image `fname`.

    The slide_id is matched against the file stem (e.g. 'slide_1'), then the
    base name ('slide_1.png'), then the full path.

    Raises:
        KeyError: if no slide_id matches the file.
    """
    base = os.path.basename(fname)
    stem = os.path.splitext(base)[0]
    for key in (stem, base, fname):
        if key in lookup:
            return lookup[key]
    raise KeyError(f"no row in the bag-label CSV matches {fname!r} (tried slide_id {stem!r})")


# Index the CSV rows once instead of searching the list for every image.
bag_label_lookup = {row['slide_id']: row for row in bag_labels}

for phase, file_indices in phases.items():  # for each phase, loop through its files
    print(phase)

    # row 0: class values; row 1: number of patch pixels contributed by each class
    totals = np.zeros((2, len(classes)))
    totals[0, :] = classes

    # `with` guarantees the pytable is closed even if an image fails to process
    with tables.open_file(f"./data/{dataname}_{phase}.pytable", mode='w') as hdf5_file:
        storage["slide_ids"] = hdf5_file.create_earray(hdf5_file.root, 'slide_ids', filenameAtom, (0,))
        storage["img"] = hdf5_file.create_earray(hdf5_file.root, 'img', img_dtype,
                                                 shape=np.append([0], block_shape['img']),
                                                 chunkshape=np.append([1], block_shape['img']),
                                                 filters=filters)
        storage["labels"] = hdf5_file.create_earray(hdf5_file.root, 'labels', img_dtype,
                                                    shape=[0],
                                                    chunkshape=[1],
                                                    filters=filters)

        for filei in file_indices:
            fname = files[filei]
            print(fname)

            bgr = cv2.imread(fname)
            if bgr is None:  # cv2.imread returns None instead of raising on unreadable files
                raise OSError(f"could not read image {fname!r}")
            io = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

            # cv2 needs its own interpolation flag (PIL's BICUBIC constant maps to cv2.INTER_AREA)
            io = cv2.resize(io, (0, 0), fx=resize, fy=resize, interpolation=cv2.INTER_CUBIC)
            io = np.pad(io, [(mirror_pad_size, mirror_pad_size),
                             (mirror_pad_size, mirror_pad_size),
                             (0, 0)], mode="reflect")

            # n_patches x patch_size x patch_size x 3
            io_arr_out = _tile_image(io, patch_size, stride_size)
            n_patches = io_arr_out.shape[0]
            if n_patches == 0:
                print(f"skipping {fname}: smaller than one {patch_size}px patch")
                continue

            row = _lookup_bag_row(fname, bag_label_lookup)
            label = int(row['label'])

            # every patch (instance) inherits the label of its slide (bag), so the
            # img, slide_ids and labels arrays stay aligned one-to-one
            storage['img'].append(io_arr_out)
            storage['slide_ids'].append([fname] * n_patches)
            storage['labels'].append([label] * n_patches)

            totals[1, classes.index(label)] += n_patches * patch_size * patch_size

            print(f"matched slide_id {row['slide_id']} (label {label})\n")

        # lastly, store the per-class pixel counts
        npixels = hdf5_file.create_carray(hdf5_file.root, 'numpixels',
                                          tables.Atom.from_dtype(totals.dtype), totals.shape)
        npixels[:] = totals
