In [67]:
import torch
import torch.nn.functional as F
from torch import nn
from torchvision import models
import models_cifar

class ResNetWrapper(nn.Module):
    '''
        ResNet model wrapper for pytorch's official implementation
        - Classifying ImageNet-sized(224x224x3) images.
        Junyoung Park : jy_park@inu.ac.kr
    '''
    def __init__(self, net, n_classes, mode='imagenet', pretrained_weight=None):
        super(ResNetWrapper, self).__init__()
        self.net = net
        self.n_classes = n_classes

        if mode.lower() == 'imagenet':
            print("Wrapper mode : ResNet-ImageNet")
            if self.n_classes != self.net.fc.out_features:
                print(f"- Out channels : {n_classes}")
                self.net.fc = nn.Linear(self.net.fc.in_features, n_classes)
            self._forward = self._forward_imagenet # Use net.fc as FC

        else: # if mode == cifar
            print("Wrapper mode : ResNet-CIFAR")
            if self.n_classes != self.net.linear.out_features:
                print(f"- Out channels : {n_classes}")
                self.net.linear = nn.Linear(self.net.linear.in_features, n_classes)
            self._forward = self._forward_cifar # Use net.linear as FC

        if pretrained_weight != None:
            self._load_pretrained_weight(pretrained_weight)

    def _load_pretrained_weight(self, w_p):
        state_dict = torch.load(w_p)
        print("Load state dict with accuracy : {state_dict['acc']*100:.2f}%")
        print("- Weight dir : {w_p}")
        self.net.load_state_dict(state_dict['net'])

    def _forward_cifar(self, x):
        C1 = F.relu(self.net.bn1(self.net.conv1(x)))
        L1 = self.net.layer1(C1)
        L2 = self.net.layer2(L1)
        L3 = self.net.layer3(L2)
        L4 = self.net.layer4(L3)
        f = F.avg_pool2d(L4, 4)
        f = f.view(f.size(0), -1)
        out = self.net.linear(f)
        return (C1, L1, L2, L3, L4), out

    def _forward_imagenet(self, x):
        C1 = self.net.relu(self.net.bn1(self.net.conv1(x)))
        C1_p = self.net.maxpool(C1)
        L1 = self.net.layer1(C1_p)
        L2 = self.net.layer2(L1)
        L3 = self.net.layer3(L2)
        L4 = self.net.layer4(L3)
        f = self.net.avgpool(L4)
        f = torch.flatten(f, 1)
        out = self.net.fc(f)
        return (C1, L1, L2, L3, L4), out

    def forward(self, x):
        return self._forward(x)

        

# Test on ImageNet-size input

In [68]:
net_t = models.resnet34(pretrained=True)
net_s = models.resnet34(pretrained=False)

net_t = ResNetWrapper(net_t, 1000, mode='imagenet')
net_s = ResNetWrapper(net_s, 1000, mode='imagenet')
x = torch.randn([1,3,224,224])
y_t = net_t(x)
y_s = net_s(x)

# [0][4] : Final Conv Layer.
features_t, pred_t = y_t[0][4], y_t[1]
features_s, pred_s = y_s[0][4], y_s[1]

loss = 0

for f_t, f_s in zip(features_t, features_s):
    loss += nn.MSELoss()(f_t,f_s)

print(loss/len(features_t))



Wrapper mode : ResNet-ImageNet
Wrapper mode : ResNet-ImageNet
tensor(2.9100, grad_fn=<DivBackward0>)


# Test on CIFAR-size input

In [69]:
net_t = models_cifar.ResNet34()
net_s = models_cifar.ResNet34()

net_t = ResNetWrapper(net_t, 10, mode='cifar')
net_s = ResNetWrapper(net_s, 10, mode='cifar')
x = torch.randn([1,3,32,32])
y_t = net_t(x)
y_s = net_s(x)

# [0][4] : Final Conv Layer.
features_t, pred_t = y_t[0][4], y_t[1]
features_s, pred_s = y_s[0][4], y_s[1]

loss = 0

for f_t, f_s in zip(features_t, features_s):
    loss += nn.MSELoss()(f_t,f_s)

print(loss/len(features_t))



Wrapper mode : ResNet-CIFAR
Wrapper mode : ResNet-CIFAR
tensor(2.6612, grad_fn=<DivBackward0>)
