In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

In [15]:
nn.Linear?

In [2]:
class DenseNet(nn.Module):
    def __init__(self, encoder_widths, decoder_widths, act_fn=nn.ReLU(), out_fn=None):
        super(DenseNet, self).__init__()
        
        assert encoder_widths[-1] == decoder_widths[0], "encoder output and decoder input dims must match"

        enc_layers = {}
        for k in range(len(encoder_widths) - 1):
            enc_layers[f"enc_layer_{k}"] = nn.Linear(encoder_widths[k], encoder_widths[k + 1])
        self.enc_layers = nn.ModuleDict(enc_layers)

        dec_layers = {}
        for k in range(len(decoder_widths) - 1):
            dec_layers[f"dec_layer_{k}"] = nn.Linear(decoder_widths[k], decoder_widths[k + 1])
        self.dec_layers = nn.ModuleDict(dec_layers)

        self.act_fn = act_fn
        self.out_fn = out_fn

    def forward(self, x):
        
        for k in range(len(self.enc_layers)):
            x = self.enc_layers[f'enc_layer_{k}'](x)
            x = x if k == len(self.enc_layers) - 1 else self.act_fn(x)

        for k in range(len(self.dec_layers)):
            x = self.dec_layers[f'dec_layer_{k}'](x)
            if k == len(self.dec_layers) - 1:
                x = x if self.out_fn is None else self.out_fn(x)
            else:
                x = self.act_fn(x)
        return x

In [16]:
net = DenseNet(encoder_widths = [784, 200,100,8], decoder_widths = [8,100,200,784])

net.enc_layers['enc_layer_2'](net.enc_layers['enc_layer_1'](net.enc_layers['enc_layer_0'](inputs))).shape


torch.Size([4, 8])

In [18]:


net = DenseNet(encoder_widths = [784, 200,100,8], decoder_widths = [8,100,200,784])
criterion = nn.CrossEntropyLoss()

trainset = torchvision.datasets.MNIST(
    root="./data/mnist", train=True, download=True, transform=transforms.ToTensor()
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
running_loss = 0

for epoch in range(2):
    for i, data in enumerate(trainloader, 0):

        inputs, labels = data
        
        inputs = nn.Flatten()(inputs)

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()


        # print statistics
        running_loss += loss.item()
        if i % 100 == 0:    # print every 100 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0


[1,     1] loss: 0.003
[1,   101] loss: 0.327
[1,   201] loss: 0.308
[1,   301] loss: 0.206
[1,   401] loss: 0.124
[1,   501] loss: 0.121
[1,   601] loss: 0.123
[1,   701] loss: 0.120
[1,   801] loss: 0.118
[1,   901] loss: 0.121
[1,  1001] loss: 0.118
[1,  1101] loss: 0.118
[1,  1201] loss: 0.117
[1,  1301] loss: 0.117
[1,  1401] loss: 0.113
[1,  1501] loss: 0.111
[1,  1601] loss: 0.108
[1,  1701] loss: 0.101
[1,  1801] loss: 0.091
[1,  1901] loss: 0.090
[1,  2001] loss: 0.087
[1,  2101] loss: 0.081
[1,  2201] loss: 0.080
[1,  2301] loss: 0.079
[1,  2401] loss: 0.076
[1,  2501] loss: 0.076
[1,  2601] loss: 0.068
[1,  2701] loss: 0.071
[1,  2801] loss: 0.063
[1,  2901] loss: 0.054
[1,  3001] loss: 0.055
[1,  3101] loss: 0.049
[1,  3201] loss: 0.046
[1,  3301] loss: 0.049
[1,  3401] loss: 0.045
[1,  3501] loss: 0.038
[1,  3601] loss: 0.043
[1,  3701] loss: 0.041
[1,  3801] loss: 0.032
[1,  3901] loss: 0.043
[1,  4001] loss: 0.035
[1,  4101] loss: 0.031
[1,  4201] loss: 0.037
[1,  4301] 

In [29]:
inputs.view(-1, 784).shape

torch.Size([4, 784])

In [32]:
for i, data in enumerate(trainloader, 0):

    inputs, labels = data
         
    inputs = nn.Flatten()(inputs)

    net(inputs)

KeyError: 'dec_layer_0'

In [15]:
nn.Linear?


<generator object Module.parameters at 0x7f9e7ab95660>