In [None]:
import os
import glob
import joblib
import numpy as np
import pandas as pd

from tqdm import tqdm
from fastai.vision.all import *
from utils.fastai_utils import train
from utils.ml_model import get_avgmodel
from utils.heatmap import process_one_tif
from utils.train_patch import get_labeled_df, get_zero_df
from utils.extract_feature_probsmap import get_probsmap_feature

# Configuration (edit before running)

In [None]:
# Notebook-wide settings. Everything a reader may need to change lives in this cell.

# Input CSV files describing the training set.
train_meta_path = '/home/fm/tissuenet/data/train_meta/train_metadata_eRORy1H.csv'  # training metadata CSV
train_anno_path = '/home/fm/tissuenet/data/train_meta/train_annotations_lbzOVuS.csv' # training annotation CSV
train_label_path = '/home/fm/tissuenet/data/train_meta/train_labels.csv' # training label CSV (read with the 'filename' and '0' columns below)

# Directories. Keep the trailing '/': paths may be built from these by plain string concatenation.
tif_base_path = '/home/fm/tissuenet/data/train/tif/' # directory holding the training .tif slides
model_save_path = './models/' # directory where trained models are saved
heatmap_save_path = './heatmap/' # directory where per-slide heatmaps (WSI prediction results) are saved

patch_model_name = 'patch_model' # name under which the patch-level fastai learner is saved
wsi_model_name = 'wsi_model.m' # file name of the WSI-level machine-learning model

# Base config
read_level = 2  # slide level from which patches are read
up_level = 3 # forwarded to get_zero_df as `up_level`; exact meaning is not documented here -- TODO confirm
down_sample = 256 # downsample ratio (256) relative to the level-2 slide size
patch_numbers = 40 # number of patches randomly chosen from each slide labelled 0
infer_bs_size = 16 # batch size used at inference time

# Patch-model training hyperparameters.
lr=2e-2 # training learning rate
epochs=10 # number of training epochs
img_size=320 # image size
bs_size=16 # training batch size
model = densenet201 # model architecture (name provided by the fastai star import)

In [None]:
# Read the three CSV tables describing the training slides.
train_meta = pd.read_csv(train_meta_path)
train_anno = pd.read_csv(train_anno_path)
train_label = pd.read_csv(train_label_path)

# File names of the slides whose '0' column equals 1.
zero_mask = train_label['0'] == 1
zero_list = train_label.loc[zero_mask, 'filename'].tolist()

# All file names, plus one integer label per slide: the position of the
# largest value among every column after the first one.
file_list = np.array(train_label['filename'].tolist())
label_columns = np.array(train_label)[:, 1:]
label_list = np.argmax(label_columns, axis=1)

# Step 1: Get train dataframe

We collect every training patch that has a known label: patches taken from the annotated slides and patches taken from the slides labelled 0.

In [None]:
# Patch dataframe built from the annotation table.
# NOTE(review): the empty list is the helper's second positional argument;
# its meaning is not visible in this notebook -- confirm in utils.train_patch.
labeled_df = get_labeled_df(
    train_anno,
    [],
    read_level,
    img_size,
    base_path=tif_base_path,
)

# Patch dataframe built from the slides listed in zero_list.
zero_df = get_zero_df(
    zero_list,
    [],
    read_level,
    down_sample,
    up_level,
    patch_numbers,
    img_size,
    base_path=tif_base_path,
)

In [None]:
# Stack the two patch tables into one training dataframe.
# DataFrame.append was deprecated in pandas 1.4 and removed in 2.0; pd.concat
# with its defaults (original indices kept, columns unioned, no sorting) is the
# supported equivalent of labeled_df.append(zero_df).
train_df = pd.concat([labeled_df, zero_df])
# Columns present in only one of the two tables come out as NaN; fill those with False.
train_df = train_df.fillna(False)

# Step 2: Patch-level classification train

Train a DenseNet-201 patch-level classification model.

In [None]:
# Fit the patch classifier with the architecture and hyperparameters from the config cell.
learn = train(
    train_df,
    model,
    lr,
    epochs,
    img_size,
    bs_size,
)

In [None]:
# Point the learner at the output directory, make sure that directory exists,
# then write the trained patch model under the configured name.
learn.model_dir = model_save_path
os.makedirs(model_save_path, exist_ok=True)
learn.save(patch_model_name)

# Step 3: Generate heatmap

We extract probability maps (heatmaps) from all training WSIs.

In [None]:
# Location of the saved patch model, handed to the heatmap generator.
model_path = os.path.join(model_save_path, patch_model_name)
# Output directory for the per-slide heatmaps; created if it does not exist.
os.makedirs(heatmap_save_path, exist_ok=True)

In [None]:
# Run the patch model over every training slide and store one array per slide.
# os.path.join keeps this working whether or not tif_base_path ends with a
# separator (plain concatenation would silently match nothing without it).
# sorted() gives a stable processing order, since glob returns files in
# arbitrary order.
wsi_list = sorted(glob.glob(os.path.join(tif_base_path, '*.tif')))

for item in tqdm(wsi_list):
    result = process_one_tif(item, down_sample, read_level, model_path, model, img_size, infer_bs_size)
    # Slide id = file name up to its first '.'; basename copes with any path
    # separator the platform uses, unlike splitting on '/'.
    slide_name = os.path.basename(item).split('.')[0]
    # np.save appends the '.npy' extension to the target name.
    np.save(os.path.join(heatmap_save_path, slide_name), result)

# Step 4: WSI-level classification train

We use an averaged machine-learning model (`get_avgmodel`) to classify each WSI.

In [None]:
# Slide ids that contributed at least one training patch. Each 'region' entry
# is a comma-separated string whose first field is the slide path.
# sorted() is used instead of list(set(...)): the iteration order of a set of
# strings changes between interpreter runs, which made the sample order fed to
# the WSI model non-reproducible.
file_list = sorted({
    os.path.basename(region.split(',')[0]).split('.')[0]
    for region in train_df['region'].tolist()
})

# Integer label per slide: position of the first 1 in that slide's row of
# train_label, minus one for the leading filename column.
label = np.array([
    train_label[train_label['filename'] == name + '.tif'].iloc[0].tolist().index(1) - 1
    for name in file_list
])

# One feature vector per slide, computed from its saved heatmap file.
feature = np.array([
    get_probsmap_feature(os.path.join(heatmap_save_path, name + '.npy'))
    for name in file_list
])

In [None]:
# Build the WSI-level model from the per-slide features and their integer labels.
avg_model = get_avgmodel(feature, label)

In [None]:
# Persist the WSI-level model next to the patch model.
wsi_model_path = os.path.join(model_save_path, wsi_model_name)
joblib.dump(avg_model, wsi_model_path)