In [1]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
from torch.autograd import Variable

plt.ion()   # interactive mode

In [7]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 3, 1, 1) #6*224*224
        #pool
        self.conv2 = nn.Conv2d(6, 16, 3, 1, 1) #16*112*112
        self.pool = nn.MaxPool2d(2, 2) 
        self.conv3 = nn.Conv2d(16, 32, 3, 1, 1) #32*56*56
        #pool
        
        self.fc1 = nn.Linear(32 * 28 * 28, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 2)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))

        x = x.view(-1, 32 * 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()

In [8]:
net = torch.load("./save")

In [9]:
net.eval()

Net(
  (conv1): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=25088, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=2, bias=True)
)

In [10]:
net.state_dict()

OrderedDict([('conv1.weight', tensor([[[[-0.2441, -0.2778, -0.2062],
                        [-0.0332, -0.0190, -0.0407],
                        [-0.1217, -0.0447,  0.1399]],
              
                       [[-0.2346, -0.1734,  0.0174],
                        [-0.1237,  0.0874, -0.1303],
                        [ 0.2598,  0.2698,  0.1834]],
              
                       [[-0.1711,  0.2013, -0.0108],
                        [ 0.2390,  0.1699,  0.2012],
                        [ 0.3731,  0.4240,  0.2861]]],
              
              
                      [[[-0.1657, -0.3149, -0.3521],
                        [-0.0936, -0.0921, -0.2579],
                        [-0.0253, -0.3142, -0.0846]],
              
                       [[ 0.0237, -0.0080, -0.4019],
                        [ 0.1336, -0.0912, -0.2492],
                        [ 0.0127, -0.2033, -0.3646]],
              
                       [[ 0.4443,  0.3079,  0.1879],
                        [ 0.3296,  0.270