Skip to content

Commit

Permalink
Fix Inception model to match the original
Browse files Browse the repository at this point in the history
This change adapts the weights used by the model to the ones in the
original Tensorflow implementation. The Inception weights available in
torchvision do not correspond to the ones used by Tensorflow's FID.

Furthermore, some subtle differences in the model implementation were
changed to now match the Tensorflow implementation.
  • Loading branch information
mseitzer committed May 27, 2019
1 parent bb3771b commit f64228c
Showing 1 changed file with 175 additions and 9 deletions.
184 changes: 175 additions & 9 deletions inception.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

try:
from torchvision.models.utils import load_state_dict_from_url
except ImportError:
from torch.utils.model_zoo import load_url as load_state_dict_from_url

# Inception weights ported to Pytorch from
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'


class InceptionV3(nn.Module):
"""Pretrained InceptionV3 network returning feature maps"""
Expand All @@ -22,7 +32,8 @@ def __init__(self,
output_blocks=[DEFAULT_BLOCK_INDEX],
resize_input=True,
normalize_input=True,
requires_grad=False):
requires_grad=False,
use_fid_inception=True):
"""Build pretrained InceptionV3
Parameters
Expand All @@ -39,11 +50,19 @@ def __init__(self,
layers is fully convolutional, it should be able to handle inputs
of arbitrary size, so resizing might not be strictly needed
normalize_input : bool
If true, normalizes the input to the statistics the pretrained
Inception network expects
If true, scales the input from range (0, 1) to the range the
pretrained Inception network expects, namely (-1, 1)
requires_grad : bool
If true, parameters of the model require gradient. Possibly useful
If true, parameters of the model require gradients. Possibly useful
for finetuning the network
use_fid_inception : bool
If true, uses the pretrained Inception model used in Tensorflow's
FID implementation. If false, uses the pretrained Inception model
available in torchvision. The FID Inception model has different
weights and a slightly different structure from torchvision's
Inception model. If you want to compute FID scores, you are
strongly advised to set this parameter to true to get comparable
results.
"""
super(InceptionV3, self).__init__()

Expand All @@ -57,7 +76,10 @@ def __init__(self,

self.blocks = nn.ModuleList()

inception = models.inception_v3(pretrained=True)
if use_fid_inception:
inception = fid_inception_v3()
else:
inception = models.inception_v3(pretrained=True)

# Block 0: input to maxpool1
block0 = [
Expand Down Expand Up @@ -128,10 +150,7 @@ def forward(self, inp):
align_corners=False)

if self.normalize_input:
x = x.clone()
x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)

for idx, block in enumerate(self.blocks):
x = block(x)
Expand All @@ -142,3 +161,150 @@ def forward(self, inp):
break

return outp


def fid_inception_v3():
"""Build pretrained Inception model for FID computation
The Inception model for FID computation uses a different set of weights
and has a slightly different structure than torchvision's Inception.
This method first constructs torchvision's Inception and then patches the
necessary parts that are different in the FID Inception model.
"""
inception = models.inception_v3(num_classes=1008,
aux_logits=False,
pretrained=False)
inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
inception.Mixed_7b = FIDInceptionE_1(1280)
inception.Mixed_7c = FIDInceptionE_2(2048)

state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
inception.load_state_dict(state_dict)
return inception


class FIDInceptionA(models.inception.InceptionA):
"""InceptionA block patched for FID computation"""
def __init__(self, in_channels, pool_features):
super(FIDInceptionA, self).__init__(in_channels, pool_features)

def forward(self, x):
branch1x1 = self.branch1x1(x)

branch5x5 = self.branch5x5_1(x)
branch5x5 = self.branch5x5_2(branch5x5)

branch3x3dbl = self.branch3x3dbl_1(x)
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

# Patch: Tensorflow's average pool does not use the padded zero's in
# its average calculation
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
count_include_pad=False)
branch_pool = self.branch_pool(branch_pool)

outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
return torch.cat(outputs, 1)


class FIDInceptionC(models.inception.InceptionC):
"""InceptionC block patched for FID computation"""
def __init__(self, in_channels, channels_7x7):
super(FIDInceptionC, self).__init__(in_channels, channels_7x7)

def forward(self, x):
branch1x1 = self.branch1x1(x)

branch7x7 = self.branch7x7_1(x)
branch7x7 = self.branch7x7_2(branch7x7)
branch7x7 = self.branch7x7_3(branch7x7)

branch7x7dbl = self.branch7x7dbl_1(x)
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

# Patch: Tensorflow's average pool does not use the padded zero's in
# its average calculation
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
count_include_pad=False)
branch_pool = self.branch_pool(branch_pool)

outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
return torch.cat(outputs, 1)


class FIDInceptionE_1(models.inception.InceptionE):
"""First InceptionE block patched for FID computation"""
def __init__(self, in_channels):
super(FIDInceptionE_1, self).__init__(in_channels)

def forward(self, x):
branch1x1 = self.branch1x1(x)

branch3x3 = self.branch3x3_1(x)
branch3x3 = [
self.branch3x3_2a(branch3x3),
self.branch3x3_2b(branch3x3),
]
branch3x3 = torch.cat(branch3x3, 1)

branch3x3dbl = self.branch3x3dbl_1(x)
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
branch3x3dbl = [
self.branch3x3dbl_3a(branch3x3dbl),
self.branch3x3dbl_3b(branch3x3dbl),
]
branch3x3dbl = torch.cat(branch3x3dbl, 1)

# Patch: Tensorflow's average pool does not use the padded zero's in
# its average calculation
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
count_include_pad=False)
branch_pool = self.branch_pool(branch_pool)

outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
return torch.cat(outputs, 1)


class FIDInceptionE_2(models.inception.InceptionE):
"""Second InceptionE block patched for FID computation"""
def __init__(self, in_channels):
super(FIDInceptionE_2, self).__init__(in_channels)

def forward(self, x):
branch1x1 = self.branch1x1(x)

branch3x3 = self.branch3x3_1(x)
branch3x3 = [
self.branch3x3_2a(branch3x3),
self.branch3x3_2b(branch3x3),
]
branch3x3 = torch.cat(branch3x3, 1)

branch3x3dbl = self.branch3x3dbl_1(x)
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
branch3x3dbl = [
self.branch3x3dbl_3a(branch3x3dbl),
self.branch3x3dbl_3b(branch3x3dbl),
]
branch3x3dbl = torch.cat(branch3x3dbl, 1)

# Patch: The FID Inception model uses max pooling instead of average
# pooling. This is likely an error in this specific Inception
# implementation, as other Inception models use average pooling here
# (which matches the description in the paper).
branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
branch_pool = self.branch_pool(branch_pool)

outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
return torch.cat(outputs, 1)

0 comments on commit f64228c

Please sign in to comment.