Support calculating features from Imagenet models
- Features (layer activations) can now be generated from Imagenet-pretrained models (rather than only Slideflow-trained models) by passing a loaded Tensorflow/Keras or Torch model to `DatasetFeatures`, rather than the path to a saved model.
- `Project.generate_features_for_clam()` now accepts an architecture name for the `model` argument, so features can be calculated from a pretrained (but not Slideflow-trained and saved) model. If an architecture name is provided rather than the path to a saved model, a Dataset must be passed explicitly to `dataset`, e.g. `Project.generate_features_for_clam(model='xception', dataset=...)` (see the usage sketch below).
- Improved the CLAM error message raised when slides have insufficient tiles (min_tiles=8)
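A minimal usage sketch of the new interfaces (the project path, tile sizes, and the `sf.model.DatasetFeatures` import path are illustrative assumptions, not part of this commit):

```python
import slideflow as sf
import tensorflow as tf

# Hypothetical project and dataset configuration.
P = sf.Project('/path/to/project')
dataset = P.dataset(tile_px=299, tile_um=302)

# Generate features from an ImageNet-pretrained Keras model passed directly,
# rather than from the path to a saved Slideflow model.
imagenet_model = tf.keras.applications.Xception(weights='imagenet')
dfs = sf.model.DatasetFeatures(model=imagenet_model, dataset=dataset)

# Generate CLAM features from an architecture name; the dataset must be
# passed explicitly when a name (not a saved model path) is given.
P.generate_features_for_clam(model='xception', dataset=dataset)
```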
jamesdolezal committed Jul 8, 2022
1 parent f6811d7 commit 86d9c5d
Showing 7 changed files with 182 additions and 109 deletions.
57 changes: 34 additions & 23 deletions slideflow/clam/models/model_clam.py
@@ -2,7 +2,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from slideflow.clam.utils import initialize_weights

"""
@@ -11,7 +10,7 @@
L: input feature dimension
D: hidden layer dimension
dropout: whether to use dropout (p = 0.25)
n_classes: number of classes
n_classes: number of classes
"""
class Attn_Net(nn.Module):

@@ -25,9 +24,9 @@ def __init__(self, L = 1024, D = 256, dropout = False, n_classes = 1):
self.module.append(nn.Dropout(0.25))

self.module.append(nn.Linear(D, n_classes))

self.module = nn.Sequential(*self.module)

def forward(self, x):
return self.module(x), x # N x n_classes

@@ -37,15 +36,15 @@ def forward(self, x):
L: input feature dimension
D: hidden layer dimension
dropout: whether to use dropout (p = 0.25)
n_classes: number of classes
n_classes: number of classes
"""
class Attn_Net_Gated(nn.Module):
def __init__(self, L = 1024, D = 256, dropout = False, n_classes = 1):
super(Attn_Net_Gated, self).__init__()
self.attention_a = [
nn.Linear(L, D),
nn.Tanh()]

self.attention_b = [nn.Linear(L, D),
nn.Sigmoid()]
if dropout:
@@ -54,7 +53,7 @@ def __init__(self, L = 1024, D = 256, dropout = False, n_classes = 1):

self.attention_a = nn.Sequential(*self.attention_a)
self.attention_b = nn.Sequential(*self.attention_b)

self.attention_c = nn.Linear(D, n_classes)

def forward(self, x):
@@ -71,17 +70,17 @@ def forward(self, x):
dropout: whether to use dropout
k_sample: number of positive/neg patches to sample for instance-level training
dropout: whether to use dropout (p = 0.25)
n_classes: number of classes
n_classes: number of classes
instance_loss_fn: loss function to supervise instance-level training
subtyping: whether it's a subtyping problem
"""
class CLAM_SB(nn.Module):
def __init__(self, gate = True, size_arg = "small", dropout = False, k_sample=8, n_classes=2,
instance_loss_fn=nn.CrossEntropyLoss(), subtyping=False):
super(CLAM_SB, self).__init__()

self.size_dict = {"small": [1024, 512, 256], "big": [1024, 512, 384] , "multiscale": [2048, 512, 256], 'xception':[2048,256,128], 'xception_multi': [1880,128,64], 'xception_3800': [3800, 512, 256]}

if type(size_arg) == str:
size = self.size_dict[size_arg]
else:
@@ -111,20 +110,26 @@ def relocate(self):
self.attention_net = self.attention_net.to(device)
self.classifiers = self.classifiers.to(device)
self.instance_classifiers = self.instance_classifiers.to(device)

@staticmethod
def create_positive_targets(length, device):
return torch.full((length, ), 1, device=device).long()
@staticmethod
def create_negative_targets(length, device):
return torch.full((length, ), 0, device=device).long()

#instance-level evaluation for in-the-class attention branch
def inst_eval(self, A, h, classifier):
def inst_eval(self, A, h, classifier):
device=h.device
if len(A.shape) == 1:
A = A.view(1, -1)
top_p_ids = torch.topk(A, self.k_sample)[1][-1]
try:
top_p_ids = torch.topk(A, self.k_sample)[1][-1]
except RuntimeError as e:
raise RuntimeError(
f"Error selecting top_k: {e}. Verify that all slides have "
"at least 8 tiles (min_tiles=8)."
)
top_p = torch.index_select(h, dim=0, index=top_p_ids)
top_n_ids = torch.topk(-A, self.k_sample, dim=1)[1][-1]
top_n = torch.index_select(h, dim=0, index=top_n_ids)
@@ -137,13 +142,19 @@ def inst_eval(self, A, h, classifier):
all_preds = torch.topk(logits, 1, dim = 1)[1].squeeze(1)
instance_loss = self.instance_loss_fn(logits, all_targets)
return instance_loss, all_preds, all_targets

#instance-level evaluation for out-of-the-class attention branch
def inst_eval_out(self, A, h, classifier):
device=h.device
if len(A.shape) == 1:
A = A.view(1, -1)
top_p_ids = torch.topk(A, self.k_sample)[1][-1]
try:
top_p_ids = torch.topk(A, self.k_sample)[1][-1]
except RuntimeError as e:
raise RuntimeError(
f"Error selecting top_k: {e}. Verify that all slides have "
"at least 8 tiles (min_tiles=8)."
)
top_p = torch.index_select(h, dim=0, index=top_p_ids)
p_targets = self.create_negative_targets(self.k_sample, device)
logits = classifier(top_p)
@@ -153,7 +164,7 @@ def inst_eval_out(self, A, h, classifier):

def forward(self, h, label=None, instance_eval=False, return_features=False, attention_only=False):
device = h.device
A, h = self.attention_net(h) # NxK
A, h = self.attention_net(h) # NxK
A = torch.transpose(A, 1, 0) # KxN
if attention_only:
return A
@@ -183,13 +194,13 @@ def forward(self, h, label=None, instance_eval=False, return_features=False, att

if self.subtyping:
total_inst_loss /= len(self.instance_classifiers)
M = torch.mm(A, h)

M = torch.mm(A, h)
logits = self.classifiers(M)
Y_hat = torch.topk(logits, 1, dim = 1)[1]
Y_prob = F.softmax(logits, dim = 1)
if instance_eval:
results_dict = {'instance_loss': total_inst_loss, 'inst_labels': np.array(all_targets),
results_dict = {'instance_loss': total_inst_loss, 'inst_labels': np.array(all_targets),
'inst_preds': np.array(all_preds)}
else:
results_dict = {}
@@ -229,7 +240,7 @@ def __init__(self, gate = True, size_arg = "small", dropout = False, k_sample=8,

def forward(self, h, label=None, instance_eval=False, return_features=False, attention_only=False):
device = h.device
A, h = self.attention_net(h) # NxK
A, h = self.attention_net(h) # NxK
A = torch.transpose(A, 1, 0) # KxN
if attention_only:
return A
@@ -260,14 +271,14 @@ def forward(self, h, label=None, instance_eval=False, return_features=False, att
if self.subtyping:
total_inst_loss /= len(self.instance_classifiers)

M = torch.mm(A, h)
M = torch.mm(A, h)
logits = torch.empty(1, self.n_classes).float().to(device)
for c in range(self.n_classes):
logits[0, c] = self.classifiers[c](M[c])
Y_hat = torch.topk(logits, 1, dim = 1)[1]
Y_prob = F.softmax(logits, dim = 1)
if instance_eval:
results_dict = {'instance_loss': total_inst_loss, 'inst_labels': np.array(all_targets),
results_dict = {'instance_loss': total_inst_loss, 'inst_labels': np.array(all_targets),
'inst_preds': np.array(all_preds)}
else:
results_dict = {}
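The try/except blocks added to `inst_eval` and `inst_eval_out` above wrap `torch.topk`, which raises a RuntimeError whenever a slide supplies fewer tiles than `k_sample` (8 by default). A minimal, standalone reproduction of the underlying failure (the tile count here is illustrative):

```python
import torch

A = torch.rand(1, 5)   # attention scores for a slide with only 5 tiles
try:
    torch.topk(A, 8)   # k_sample=8 exceeds the 5 available tiles
except RuntimeError as e:
    print(f'topk failed: {e}')  # the commit re-raises this with min_tiles=8 guidance
```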
62 changes: 42 additions & 20 deletions slideflow/model/__init__.py
@@ -15,7 +15,7 @@
from collections import defaultdict
from math import isnan
from os.path import exists, join
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
@@ -27,6 +27,10 @@
from slideflow.util import log
from tqdm import tqdm

if TYPE_CHECKING:
import tensorflow as tf
import torch

# --- Backend-specific imports ------------------------------------------------

if os.environ['SF_BACKEND'] == 'tensorflow':
@@ -149,7 +153,7 @@ class DatasetFeatures:

def __init__(
self,
model: Path,
model: Union[str, "tf.keras.models.Model", "torch.nn.Module"],
dataset: "sf.Dataset",
annotations: Optional[Labels] = None,
cache: Optional[str] = None,
@@ -187,16 +191,23 @@ def __init__(
self.annotations = annotations
self.model = model
self.dataset = dataset
self.tile_px = dataset.tile_px
self.tfrecords = np.array(dataset.tfrecords())
self.slides = sorted([sf.util.path_to_name(t) for t in self.tfrecords])
model_config = sf.util.get_model_config(model)
self.tile_px = model_config['tile_px']
self.hp = ModelParams.from_dict(model_config['hp'])
self.normalizer = self.hp.get_normalizer()
if self.normalizer:
log.info(f'Using realtime {self.normalizer.method} normalization')
if 'norm_fit' in model_config:
self.normalizer.fit(**model_config['norm_fit'])

# Load configuration if model is path to a saved model
if isinstance(model, str):
model_config = sf.util.get_model_config(model)
hp = ModelParams.from_dict(model_config['hp'])
self.uq = hp.uq
self.normalizer = hp.get_normalizer()
if self.normalizer:
log.info(f'Using realtime {self.normalizer.method} normalization')
if 'norm_fit' in model_config:
self.normalizer.fit(**model_config['norm_fit'])
else:
self.normalizer = None
self.uq = False

if self.annotations:
self.categories = list(set(self.annotations.values()))
@@ -269,7 +280,7 @@ def __init__(

def _generate_from_model(
self,
model: Path,
model: Union[str, "tf.keras.models.Model", "torch.nn.Module"],
layers: Union[str, List[str]] = 'postconv',
include_logits: bool = True,
include_uncertainty: bool = True,
@@ -300,17 +311,28 @@ def _generate_from_model(
layers = sf.util.as_list(layers)

# Load model
if self.hp.uq and include_uncertainty:
feat_kw = dict(
layers=layers,
include_logits=include_logits
)
if self.uq and include_uncertainty:
combined_model = sf.model.UncertaintyInterface(
model,
layers=layers
)
else:
combined_model = sf.model.Features( # type: ignore
elif isinstance(model, str):
combined_model = sf.model.Features(model, **feat_kw)
elif sf.backend() == 'tensorflow':
combined_model = sf.model.Features.from_model(model, **feat_kw)
elif sf.backend() == 'torch':
combined_model = sf.model.Features.from_model(
model,
layers=layers,
include_logits=include_logits
tile_px=self.tile_px,
**feat_kw
)
else:
raise ValueError(f'Unrecognized model {model}')

self.num_features = combined_model.num_features
self.num_logits = 0 if not include_logits else combined_model.num_logits

@@ -375,7 +397,7 @@ def batch_worker():
batch_loc = np.stack([batch_loc[0], batch_loc[1]], axis=1)

# Process model outputs
if self.hp.uq and include_uncertainty:
if self.uq and include_uncertainty:
uncertainty = model_out[-1]
model_out = model_out[:-1]
else:
@@ -395,7 +417,7 @@ def batch_worker():
self.activations[slide].append(batch_act[d])
if include_logits:
self.logits[slide].append(logits[d])
if self.hp.uq and include_uncertainty:
if self.uq and include_uncertainty:
self.uncertainty[slide].append(uncertainty[d])
self.locations[slide].append(batch_loc[d])

@@ -605,11 +627,11 @@ def to_torch(
)
torch.save(slide_activations, join(outdir, f'{slide}.pt'))
args = {
'model': self.model,
'model': self.model if isinstance(self.model, str) else '<NA>',
'num_features': self.num_features
}
sf.util.write_json(args, join(outdir, 'settings.json'))
log.info('Activations exported in Torch format.')
log.info(f'Activations exported in Torch format to {outdir}')

def to_df(
self
41 changes: 21 additions & 20 deletions slideflow/model/tensorflow.py
@@ -43,6 +43,27 @@ def call(self, inputs, **kwargs):
class ModelParams(_base._ModelParams):
"""Build a set of hyperparameters."""

ModelDict = {
'xception': kapps.Xception,
'vgg16': kapps.VGG16,
'vgg19': kapps.VGG19,
'resnet50': kapps.ResNet50,
'resnet101': kapps.ResNet101,
'resnet152': kapps.ResNet152,
'resnet50_v2': kapps.ResNet50V2,
'resnet101_v2': kapps.ResNet101V2,
'resnet152_v2': kapps.ResNet152V2,
'inception': kapps.InceptionV3,
'nasnet_large': kapps.NASNetLarge,
'inception_resnet_v2': kapps.InceptionResNetV2,
'mobilenet': kapps.MobileNet,
'mobilenet_v2': kapps.MobileNetV2,
# 'ResNeXt50': kapps.ResNeXt50,
# 'ResNeXt101': kapps.ResNeXt101,
# 'DenseNet': kapps.DenseNet,
# 'NASNet': kapps.NASNet
}

def __init__(self, *args, **kwargs):
self.OptDict = {
'Adam': tf.keras.optimizers.Adam,
@@ -53,26 +74,6 @@ def __init__(self, *args, **kwargs):
'Adamax': tf.keras.optimizers.Adamax,
'Nadam': tf.keras.optimizers.Nadam
}
self.ModelDict = {
'xception': kapps.Xception,
'vgg16': kapps.VGG16,
'vgg19': kapps.VGG19,
'resnet50': kapps.ResNet50,
'resnet101': kapps.ResNet101,
'resnet152': kapps.ResNet152,
'resnet50_v2': kapps.ResNet50V2,
'resnet101_v2': kapps.ResNet101V2,
'resnet152_v2': kapps.ResNet152V2,
'inception': kapps.InceptionV3,
'nasnet_large': kapps.NASNetLarge,
'inception_resnet_v2': kapps.InceptionResNetV2,
'mobilenet': kapps.MobileNet,
'mobilenet_v2': kapps.MobileNetV2,
# 'ResNeXt50': kapps.ResNeXt50,
# 'ResNeXt101': kapps.ResNeXt101,
# 'DenseNet': kapps.DenseNet,
# 'NASNet': kapps.NASNet
}
if hasattr(kapps, 'EfficientNetV2B0'):
self.ModelDict.update({'efficientnet_v2b0': kapps.EfficientNetV2B0})
if hasattr(kapps, 'EfficientNetV2B1'):
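Moving `ModelDict` from the `ModelParams` constructor to a class attribute (above) makes the architecture-name-to-constructor mapping available without instantiating `ModelParams`, which is presumably what allows a bare name such as 'xception' to be resolved to an ImageNet-pretrained model. A hedged sketch of such a lookup (this usage is illustrative, not taken from this commit):

```python
from slideflow.model.tensorflow import ModelParams

# Class-level access: no ModelParams instance (and no hyperparameters) required.
keras_constructor = ModelParams.ModelDict['xception']
imagenet_model = keras_constructor(weights='imagenet')
```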
