In [19]:
import torch
import torch.nn as nn
# allows text-based summary of model
from torchsummary import summary
# for visual summary of model
from torchviz import make_dot
# can visualise pytorch model architectures with tensorboardX
from tensorboardX import SummaryWriter

### Load model

In [15]:
# Copied CNN class from basic.py — layer names and shapes must stay identical to
# the original so the saved state_dict loaded below matches key-for-key.
# Ordinarily, this would be stored in a module and imported instead of duplicated.
class CNN(nn.Module):
    """Two-conv-layer image classifier with 3 output classes.

    Pipeline (see ``forward``): conv(3->16) -> ReLU -> 2x2 max-pool ->
    conv(16->32) -> ReLU -> 2x2 max-pool -> flatten -> fc(32*32*32 -> 128)
    -> ReLU -> fc(128 -> 3).

    The two pooling layers each halve height and width, and ``fc1`` expects a
    32x32x32 feature map. The network is therefore tied to 3x128x128 input
    images.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * 32 * 32, 128)
        self.fc2 = nn.Linear(128, 3)  # 3 possible classes in the dataset

    def forward(self, x):
        """Return raw (un-softmaxed) class scores of shape ``(N, 3)``.

        Args:
            x: batched images of shape ``(N, 3, 128, 128)``.
        """
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. The previous
        # ``x.view(-1, 32 * 32 * 32)`` folded spatial elements into the batch
        # axis when the input was not 128x128. With flatten, a wrong input size
        # raises a shape error in fc1 instead of silently returning the wrong
        # number of predictions.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

model = CNN()

In [16]:
save_path = '../../models/classification/bean-leaf-lesions/basic.pth'
# map_location='cpu' lets a checkpoint that was saved from a GPU run load on a
# CPU-only kernel. The model here was never moved off the CPU.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
state_dict = torch.load(save_path, map_location='cpu')
model.load_state_dict(state_dict)

<All keys matched successfully>

### Visualise Model Architecture

In [17]:
# Input geometry the CNN was built for: after two 2x2 poolings, fc1 expects a
# 32x32x32 feature map, i.e. 3x128x128 images.
channels = 3
width = 128
height = 128
# torchsummary's `summary` defaults to device="cuda". The model lives on the
# CPU, so be explicit; otherwise machines with a GPU hit a device mismatch.
summary(model, input_size=(channels, height, width), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 16, 128, 128]             448
         MaxPool2d-2           [-1, 16, 64, 64]               0
            Conv2d-3           [-1, 32, 64, 64]           4,640
         MaxPool2d-4           [-1, 32, 32, 32]               0
            Linear-5                  [-1, 128]       4,194,432
            Linear-6                    [-1, 3]             387
Total params: 4,199,907
Trainable params: 4,199,907
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.19
Forward/backward pass size (MB): 3.75
Params size (MB): 16.02
Estimated Total Size (MB): 19.96
----------------------------------------------------------------


In [18]:
# Any random batch works for drawing the graph. forward() has no
# data-dependent branches, so only the shapes matter, not the values.
batch_size = 32
sample_input = torch.randn(batch_size, channels, height, width)
output = model(sample_input)

# Generate visualisation: make_dot draws the computation graph that produced
# `output`. The name->parameter mapping lets it label weight nodes readably.
param_lookup = {name: param for name, param in model.named_parameters()}
dot = make_dot(output, params=param_lookup)
dot.render('model architecture', format='png')

'model architecture.png'

In [20]:
# SummaryWriter() with no logdir writes to a new runs/<datetime>_<hostname>
# directory, which the `%tensorboard --logdir=runs` cell picks up.
# The `with` block closes (and flushes) the writer even if tracing the model
# in add_graph raises, so no event file is left open.
with SummaryWriter() as writer:
    writer.add_graph(model, sample_input)

In [22]:
# Load the TensorBoard notebook extension and embed a TensorBoard view of the
# ./runs directory, where SummaryWriter() logs by default.
# If a TensorBoard server for this logdir is already running, it is reused
# rather than restarted.
%load_ext tensorboard
%tensorboard --logdir=runs

Reusing TensorBoard on port 6006 (pid 87588), started 0:00:07 ago. (Use '!kill 87588' to kill it.)