In [1]:
import torch as th
import torch.nn as nn
#from network import Network
import torch.nn.functional as F
#from pruningmethod import PruningMethod

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 50, 5)
        self.fc1 = nn.Linear(50 * 4 * 4, 800)
        self.fc2 = nn.Linear(800, 500)
        self.fc3 = nn.Linear(500, 10)
        self.a_type='relu'
        #for m in self.modules():
        #    self.weight_init(m)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        layer1 = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        layer2 = F.max_pool2d(F.relu(self.conv2(layer1)), 2)
        layer2_p = layer2.view(-1, int(layer2.nelement() / layer2.shape[0]))
        layer3= F.relu(self.fc1(layer2_p))
        layer4 = F.relu(self.fc2(layer3))
        layer5 = self.fc3(layer4)
        return layer5

In [2]:
import torch

# Load the checkpoint file
checkpoint_path = "base.pth"
checkpoint = torch.load(checkpoint_path)

# Inspect the contents of the checkpoint
print("Checkpoint keys:", checkpoint.keys())

# If it's a model state_dict, print the keys of the state_dict
if 'state_dict' in checkpoint:
    print("\nState_dict keys:", checkpoint['state_dict'].keys())

# If the checkpoint is a state_dict itself (most common case)
if isinstance(checkpoint, dict):
    print("\nVariables in the checkpoint:")
    for key, value in checkpoint.items():
        print(f"{key}: {type(value)}")


Checkpoint keys: dict_keys(['model', 'train_acc', 'test_acc', 'optimizer', 'scheduler'])

Variables in the checkpoint:
model: <class 'collections.OrderedDict'>
train_acc: <class 'float'>
test_acc: <class 'float'>
optimizer: <class 'dict'>
scheduler: <class 'dict'>


In [3]:
# Initialize your model
model = LeNet()

# Load the model state dictionary
model.load_state_dict(checkpoint['model'])

<All keys matched successfully>

In [4]:
from torchvision import datasets, transforms
trainloader = th.utils.data.DataLoader(datasets.MNIST('../data',
                                                          download=True,
                                                          train=True,
                                                          transform=transforms.Compose([
                                                              transforms.ToTensor(), # first, convert image to PyTorch tensor
                                                              #transforms.Normalize((0.5,), (0.5,)) # normalize inputs
                                                          ])),
                                           batch_size=100,
                                           shuffle=True,num_workers=0)

# download and transform test dataset
testloader = th.utils.data.DataLoader(datasets.MNIST('../data',
                                                          download=True,
                                                          train=False,
                                                          transform=transforms.Compose([
                                                              transforms.ToTensor(), # first, convert image to PyTorch tensor
                                                              #transforms.Normalize((0.5,), (0.5,)) # normalize inputs
                                                          ])),
                                           batch_size=100,
                                           shuffle=True,num_workers=0)

model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in testloader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_accuracy = 100 * correct / total
print(f"Test Accuracy: {test_accuracy:.2f}%")

Test Accuracy: 99.02%


In [6]:
from thop import profile

# Input size for the model (e.g., (batch_size, num_channels, height, width))
input = torch.randn(1, 1, 28, 28)

# Calculate FLOPs and parameters
flops, params = profile(model, inputs=(input,))

print(f"FLOPs: {flops}")
print(f"Parameters: {params}")

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
FLOPs: 2933000.0
Parameters: 1071880.0
