In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)

        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.out = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2)

        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2)

        x = x.reshape(-1, 16 * 5 * 5)
        x = self.fc1(x)
        x = F.relu(x)

        x = self.fc2(x)
        x = F.relu(x)

        x = self.out(x)
        return x

In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LeNet().to(device)
torch.save(model.state_dict(), "lenet_init.pth")

x = torch.ones((1, 3, 32, 32)).to(device)
print(model(x))

tensor([[-0.0875,  0.1503, -0.0599,  0.0087, -0.0239,  0.0358,  0.1110,  0.0427,
         -0.0074,  0.0625]], device='cuda:0', grad_fn=<AddmmBackward0>)
