Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Error checks added to ImageNet VGG #377

Merged
merged 8 commits into from
Mar 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions GANDLF/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@
efficientnetB6,
efficientnetB7,
)
from .imagenet import (
from .imagenet_vgg import (
imagenet_vgg11,
imagenet_vgg11_bn,
imagenet_vgg13,
imagenet_vgg13_bn,
imagenet_vgg16,
imagenet_vgg19,
imagenet_vgg16_bn,
imagenet_vgg19,
imagenet_vgg19_bn,
)
from .sdnet import SDNet
Expand Down Expand Up @@ -55,9 +59,13 @@
"vgg13": vgg13,
"vgg16": vgg16,
"vgg19": vgg19,
"imagenet_vgg11": imagenet_vgg11,
"imagenet_vgg11_bn": imagenet_vgg11_bn,
"imagenet_vgg13": imagenet_vgg13,
"imagenet_vgg13_bn": imagenet_vgg13_bn,
"imagenet_vgg16": imagenet_vgg16,
"imagenet_vgg19": imagenet_vgg19,
"imagenet_vgg16_bn": imagenet_vgg16_bn,
"imagenet_vgg19": imagenet_vgg19,
"imagenet_vgg19_bn": imagenet_vgg19_bn,
"densenet": densenet264,
"densenet121": densenet121,
Expand Down
81 changes: 0 additions & 81 deletions GANDLF/models/imagenet.py

This file was deleted.

169 changes: 169 additions & 0 deletions GANDLF/models/imagenet_vgg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# -*- coding: utf-8 -*-
"""
Modified from https://github.com/pytorch/vision.git
"""

import torchvision
import torch.nn as nn

from .modelBase import ModelBase


def create_torchvision_model(modelname, pretrained=True, num_classes=2):
if modelname == "vgg11":
model = torchvision.models.vgg11(pretrained=pretrained)
if modelname == "vgg11_bn":
model = torchvision.models.vgg11_bn(pretrained=pretrained)
if modelname == "vgg13":
model = torchvision.models.vgg13(pretrained=pretrained)
if modelname == "vgg13_bn":
model = torchvision.models.vgg13_bn(pretrained=pretrained)
if modelname == "vgg16":
model = torchvision.models.vgg16(pretrained=pretrained)
if modelname == "vgg16_bn":
model = torchvision.models.vgg16_bn(pretrained=pretrained)
if modelname == "vgg19":
model = torchvision.models.vgg19(pretrained=pretrained)
if modelname == "vgg19_bn":
model = torchvision.models.vgg19_bn(pretrained=pretrained)
prev_out_features = model.classifier[3].out_features
model.classifier[6] = nn.Linear(
in_features=prev_out_features, out_features=num_classes
)
return model


class imagenet_vgg11(ModelBase):
def __init__(
self,
parameters,
) -> None:
super(imagenet_vgg11, self).__init__(parameters)
if self.n_dimensions != 2:
raise ValueError("ImageNet pre-trained models only support 2D images")

self.model = create_torchvision_model(
"vgg11", pretrained=True, num_classes=self.n_classes
)

def forward(self, x):
return self.model(x)


class imagenet_vgg11_bn(ModelBase):
def __init__(
self,
parameters,
) -> None:
super(imagenet_vgg11_bn, self).__init__(parameters)
if self.n_dimensions != 2:
raise ValueError("ImageNet pre-trained models only support 2D images")

self.model = create_torchvision_model(
"vgg11_bn", pretrained=True, num_classes=self.n_classes
)

def forward(self, x):
return self.model(x)


class imagenet_vgg13(ModelBase):
def __init__(
self,
parameters,
) -> None:
super(imagenet_vgg13, self).__init__(parameters)
if self.n_dimensions != 2:
raise ValueError("ImageNet pre-trained models only support 2D images")

self.model = create_torchvision_model(
"vgg13", pretrained=True, num_classes=self.n_classes
)

def forward(self, x):
return self.model(x)


class imagenet_vgg13_bn(ModelBase):
def __init__(
self,
parameters,
) -> None:
super(imagenet_vgg13_bn, self).__init__(parameters)
if self.n_dimensions != 2:
raise ValueError("ImageNet pre-trained models only support 2D images")

self.model = create_torchvision_model(
"vgg13_bn", pretrained=True, num_classes=self.n_classes
)

def forward(self, x):
return self.model(x)


class imagenet_vgg16(ModelBase):
def __init__(
self,
parameters,
) -> None:
super(imagenet_vgg16, self).__init__(parameters)
if self.n_dimensions != 2:
raise ValueError("ImageNet pre-trained models only support 2D images")

self.model = create_torchvision_model(
"vgg16", pretrained=True, num_classes=self.n_classes
)

def forward(self, x):
return self.model(x)


class imagenet_vgg16_bn(ModelBase):
def __init__(
self,
parameters,
) -> None:
super(imagenet_vgg16_bn, self).__init__(parameters)
if self.n_dimensions != 2:
raise ValueError("ImageNet pre-trained models only support 2D images")

self.model = create_torchvision_model(
"vgg16_bn", pretrained=True, num_classes=self.n_classes
)

def forward(self, x):
return self.model(x)


class imagenet_vgg19(ModelBase):
def __init__(
self,
parameters,
) -> None:
super(imagenet_vgg19, self).__init__(parameters)
if self.n_dimensions != 2:
raise ValueError("ImageNet pre-trained models only support 2D images")

self.model = create_torchvision_model(
"vgg19", pretrained=True, num_classes=self.n_classes
)

def forward(self, x):
return self.model(x)


class imagenet_vgg19_bn(ModelBase):
def __init__(
self,
parameters,
) -> None:
super(imagenet_vgg19_bn, self).__init__(parameters)
if self.n_dimensions != 2:
raise ValueError("ImageNet pre-trained models only support 2D images")

self.model = create_torchvision_model(
"vgg19_bn", pretrained=True, num_classes=self.n_classes
)

def forward(self, x):
return self.model(x)
1 change: 1 addition & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
- Added mechanism to perform inference without having access to ground truth labels
- Added mechanism to map output labels using post-processing before saving
- Added mechanism to enable customized histology classification output via heatmaps
- ImageNet pre-trained models added

## 0.0.13

Expand Down
7 changes: 7 additions & 0 deletions testing/test_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,14 @@
]
# pre-defined regression/classification model types for testing
all_models_classification = [
"imagenet_vgg11",
"imagenet_vgg11_bn",
"imagenet_vgg13",
"imagenet_vgg13_bn",
"imagenet_vgg16",
"imagenet_vgg16_bn",
"imagenet_vgg19",
"imagenet_vgg19_bn",
"resnet18",
]

Expand Down