In [1]:
from datetime import datetime
import itertools
import logging
import numpy as np
import os
from sklearn.datasets import make_classification
from sklearn.metrics import classification_report
from skorch import NeuralNetClassifier
from skorch.callbacks import LRScheduler
from skorch.helper import predefined_split
from skorch.callbacks import Checkpoint, TrainEndCheckpoint
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

In [2]:
import scifAI
from scifAI.dl.utils import calculate_weights, train_validation_test_split, get_statistics
from scifAI.dl.dataset import DatasetGenerator
from scifAI.dl.custom_transforms import ShuffleChannel
from scifAI.dl.models import PretrainedModel, resnet18

In [3]:
seed_value = 42

os.environ['PYTHONHASHSEED']=str(seed_value)
import random
random.seed(seed_value)

np.random.seed(seed_value)
torch.manual_seed(seed_value)

<torch._C.Generator at 0x2aaaec10abd0>

In [4]:
%%time

data_path = "/pstore/data/DS4/Apoptotic_cell_detection/"
metadata = scifAI.metadata_generator(data_path)

Metadata prepration starts...
Experiment_1 Donor_1 condition_1


100%|██████████| 15311/15311 [00:02<00:00, 5567.92it/s]


...metadata prepration ended.
CPU times: user 242 ms, sys: 156 ms, total: 398 ms
Wall time: 3.22 s


#### Define all necessary parameters

In [5]:

model_dir = "models"
log_dir = "logs"
scaling_factor = 255.
reshape_size = 32
train_transform = [
         transforms.RandomVerticalFlip(),
         transforms.RandomHorizontalFlip(),
         transforms.RandomRotation(45)
        ]
test_transform = [ ]
num_classes = len(metadata.label.unique())

In [6]:

selected_channels = np.arange(2)
num_channels = len(selected_channels)
channels =np.asarray([ "Ch" + str(i) for i in selected_channels])
num_of_all_channels = len(channels)
all_channels = np.arange(num_of_all_channels)

In [7]:
batch_size = 128
num_workers = 4
device="cuda"
dataset_name = "apoptotic cells"

In [8]:
# hyperparameters for the model
lrscheduler = LRScheduler(
    policy='StepLR', step_size=7, gamma=0.5)
number_epochs = 2
lr = 0.001
momentum=0.9
optimizer = optim.SGD

In [9]:
os.makedirs(model_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

In [10]:
# initialize logging
now = datetime.now()
timestamp = datetime.timestamp(now)
logging.basicConfig(filename=os.path.join(log_dir, 'remove_and_retrain_{}_{}.txt'.format(dataset_name, timestamp)), level=logging.DEBUG)

#### Load data

In [11]:
train_transform_init = [
         transforms.RandomVerticalFlip(),
         transforms.RandomHorizontalFlip(),
         transforms.RandomRotation(45)
        ]
test_transform_init = [ ]

In [12]:
label_map = dict(zip(sorted(set(metadata["label"])), np.arange(len(set(metadata["label"])))))
num_classes = len(label_map.keys())
class_names_targets = [c for c in label_map.keys()]

In [13]:
def split_load_normalize_data(random_state=seed_value, selected_channels=[]):
    train_index, validation_index, test_index = train_validation_test_split(metadata.index, metadata["label"], random_state=seed_value)
    
    # caclculate statistics
    train_transform = train_transform_init.copy()
    test_transform = test_transform_init.copy()
    train_dataset = DatasetGenerator(metadata=metadata.loc[train_index,:],
                                 label_map=label_map,
                                 selected_channels=selected_channels,
                                 scaling_factor=scaling_factor,
                                 reshape_size=reshape_size,
                                 transform=transforms.Compose(train_transform))
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    statistics = get_statistics(train_loader, selected_channels=selected_channels)
    
    # normalize data
    train_transform.append(transforms.Normalize(mean=statistics["mean"],
                         std=statistics["mean"]))
    test_transform.append(transforms.Normalize(mean=statistics["mean"],
                         std=statistics["mean"]))
  
    
    train_dataset = DatasetGenerator(metadata=metadata.loc[train_index,:],
                                 label_map=label_map,
                                 selected_channels=selected_channels,
                                 scaling_factor=scaling_factor, 
                                 reshape_size=reshape_size,
                                 transform= transforms.Compose(train_transform))
    validation_dataset = DatasetGenerator(metadata=metadata.loc[validation_index,:],
                                      label_map=label_map,
                                      selected_channels=selected_channels,
                                      scaling_factor=scaling_factor,
                                      reshape_size=reshape_size,
                                      transform=transforms.Compose(test_transform))
    test_dataset = DatasetGenerator(metadata=metadata.loc[test_index,:],
                                    label_map=label_map,
                                    selected_channels=selected_channels,
                                    scaling_factor=scaling_factor,
                                    reshape_size=reshape_size,
                                    transform=transforms.Compose(test_transform))
    return train_dataset, validation_dataset, test_dataset

In [14]:
def train_model(train_dataset, validation_dataset, num_channels, selected_channels, seed):
    model_saved_name = '{}_net_{}_seed_{}.pth'.format(dataset_name, '_'.join(map(str,selected_channels)), seed)
    checkpoint = Checkpoint(f_params=model_saved_name, monitor='valid_loss_best', dirname='models')
    net = NeuralNetClassifier(
        PretrainedModel, 
        criterion=nn.CrossEntropyLoss,
        lr=lr,
        batch_size=batch_size,
        max_epochs=number_epochs,
        module__output_features=num_classes,
        module__num_classes=num_classes,
        module__num_channels=num_channels, 
        optimizer=optimizer,
        optimizer__momentum=momentum,
        iterator_train__shuffle=False,
        iterator_train__num_workers=num_workers,
        iterator_valid__shuffle=False,
        iterator_valid__num_workers=num_workers,
        callbacks=[lrscheduler, checkpoint],
        train_split=predefined_split(validation_dataset),
        device='cuda' # comment to train on cpu
    )
    net.fit(train_dataset, y=None)
    
    return model_saved_name

In [15]:
def load_and_eval_model(num_channels, test_dataset, path_to_the_cp=""):
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    # load model
    model = PretrainedModel(num_classes, num_channels)
    checkpoint = torch.load(os.path.join(model_dir, path_to_the_cp))
    model.load_state_dict(checkpoint)
    model = model.to(device)
    
    # evaluate
    correct = 0.
    total = 0.
    y_true = list()
    y_pred = list()
    y_true_proba = list()
    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data[0].to(device).float(), data[1].to(device).long()
            outputs = model(inputs)
            pred = outputs.argmax(dim=1)
            true_proba = np.array([j[i] for (i,j) in zip(pred, outputs.cpu())])
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (labels.reshape(-1) == predicted).sum().item()
            for i in range(len(pred)):
                y_true.append(labels[i].item())
                y_pred.append(pred[i].item())
                y_true_proba.append(true_proba[i].item())
    
    # save result
    logging.info(classification_report(y_true, y_pred, target_names=class_names_targets, digits=4))

In [16]:
def findsubsets(s, n_elements):
    return list(itertools.combinations(s, n_elements))

In [17]:
s = set(all_channels)
n_retrain = 1
for number_removed_channels in np.array([1]):
    all_combinations = findsubsets(s, num_of_all_channels - number_removed_channels)
    for channel_comb in all_combinations:
        for n in range(n_retrain):
            channel_comb = np.asarray(channel_comb)
            logging.info("Train new model: iteration {}, channels: {}".format(str(n), '_'.join(map(str, channel_comb))))
            num_channels = len(channel_comb)
            train_dataset, val_dataset, test_dataset = split_load_normalize_data(random_state=seed_value, selected_channels=channel_comb)
            model_path = train_model(train_dataset, val_dataset, num_channels, channel_comb, n)
            load_and_eval_model(num_channels, test_dataset, model_path)
            os.remove(os.path.join(model_dir, model_path))

100%|██████████| 77/77 [00:11<00:00,  6.97it/s]


statistics used: {'min': tensor([0.]), 'p01': tensor([0.]), 'p05': tensor([0.]), 'p25': tensor([2.9906]), 'p50': tensor([3.1347]), 'p75': tensor([3.1831]), 'p95': tensor([3.4469]), 'p99': tensor([3.7467]), 'max': tensor([5.2018]), 'mean': tensor([2.7733]), 'std': tensor([1.0245])}
  epoch    train_loss    valid_acc    valid_loss    cp      lr     dur
-------  ------------  -----------  ------------  ----  ------  ------
      1        [36m0.7253[0m       [32m0.4796[0m        [35m0.7347[0m     +  0.0010  5.6865
      2        [36m0.6889[0m       [32m0.5649[0m        [35m0.6858[0m     +  0.0010  5.6674


100%|██████████| 77/77 [00:10<00:00,  7.24it/s]


statistics used: {'min': tensor([0.]), 'p01': tensor([0.]), 'p05': tensor([0.]), 'p25': tensor([0.0425]), 'p50': tensor([0.0728]), 'p75': tensor([0.1598]), 'p95': tensor([0.7868]), 'p99': tensor([2.1963]), 'max': tensor([250.7146]), 'mean': tensor([0.2146]), 'std': tensor([1.8042])}
  epoch    train_loss    valid_acc    valid_loss    cp      lr     dur
-------  ------------  -----------  ------------  ----  ------  ------
      1        [36m0.6486[0m       [32m0.6180[0m        [35m0.6354[0m     +  0.0010  5.7036
      2        [36m0.4094[0m       [32m0.6869[0m        0.8362        0.0010  5.7473


In [18]:
channel_comb

array([1])

In [19]:
selected_channels

array([0, 1])

add the running time to ever step for later comparison