In [5]:
import torch
import torch.nn as nn

In [None]:
def conv_block(in_channels, out_channels, pool=False):
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 
              nn.BatchNorm2d(out_channels), 
              nn.ReLU(inplace=True)]
    if pool: layers.append(nn.MaxPool2d(2))
    return nn.Sequential(*layers)

class ResNet(ImageClassificationBase):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        
        self.input = conv_block(in_channels, 64)

        self.conv1 = conv_block(64, 64, pool=True)
        self.res1 = nn.Sequential(conv_block(64, 32), conv_block(32, 64))
        self.drop1 = nn.Dropout(0.5)
        
        self.conv2 = conv_block(64, 64, pool=True)
        self.res2 = nn.Sequential(conv_block(64, 32), conv_block(32, 64))
        self.drop2 = nn.Dropout(0.5)
        
        self.conv3 = conv_block(64, 64, pool=True)
        self.res3 = nn.Sequential(conv_block(64, 32), conv_block(32, 64))
        self.drop3 = nn.Dropout(0.5)
        
        self.classifier = nn.Sequential(nn.MaxPool2d(6), 
                                        nn.Flatten(),
                                        nn.Linear(64, num_classes))
        
    def forward(self, xb):
        out = self.input(xb)

        out = self.conv1(out)
        out = self.res1(out) + out
        out = self.drop1(out)
        
        out = self.conv2(out)
        out = self.res2(out) + out
        out = self.drop2(out)
        
        out = self.conv3(out)
        out = self.res3(out) + out
        out = self.drop3(out)
        
        return self.classifier(out)

In [2]:
model = torch.load('./emotion_detection_model_state.pth')

In [4]:
print(model)

OrderedDict([('input.0.weight', tensor([[[[ 0.2494, -0.0034,  0.2216],
          [ 0.1971, -0.2936,  0.0223],
          [-0.0363,  0.0380, -0.2277]]],


        [[[-0.0606,  0.0403,  0.0382],
          [-0.1639,  0.1317, -0.1804],
          [-0.1820,  0.1667,  0.2255]]],


        [[[-0.1639, -0.1973,  0.0439],
          [ 0.2643, -0.2236, -0.1309],
          [ 0.3653, -0.2487,  0.2781]]],


        [[[-0.1749,  0.3234, -0.0825],
          [ 0.0117,  0.3079, -0.1664],
          [ 0.1067, -0.1964, -0.2400]]],


        [[[-0.1158, -0.0045, -0.3111],
          [ 0.0352,  0.2439,  0.0125],
          [ 0.0747,  0.2052,  0.2016]]],


        [[[-0.3230, -0.0058,  0.2669],
          [ 0.2193, -0.1752, -0.0845],
          [-0.2942,  0.1748,  0.2249]]],


        [[[-0.2498, -0.1165, -0.2146],
          [ 0.1291,  0.2320,  0.2339],
          [-0.0683,  0.1054, -0.0764]]],


        [[[ 0.2483,  0.1778,  0.3301],
          [-0.1449,  0.2764, -0.2523],
          [-0.2811, -0.2458, -0.1143]]],


