In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision.transforms as transforms
import torchvision.datasets as datasets

import matplotlib.pyplot as plt


class MyModel(nn.Module):
    """Small convolutional autoencoder for single-channel images.

    Encoder: two stages of conv(3x3, stride 1, padding 1) -> max-pool(2) -> ReLU,
    going 1 -> 3 -> 6 channels and halving H and W at each stage.
    Decoder: two ConvTranspose2d(kernel 4, stride 2, padding 1) layers that each
    double H and W, going 6 -> 3 -> 1 channels. The final layer has no activation,
    so ``forward`` returns raw logits (paired with BCEWithLogitsLoss in training).
    """

    def __init__(self):
        super().__init__()
        # Encoder. Creation order is significant: it fixes parameter names/order
        # and the order in which the default initializers consume the RNG.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        # Decoder: each transposed conv maps a size-s input to size 2*s.
        self.conv_trans1 = nn.ConvTranspose2d(in_channels=6, out_channels=3, kernel_size=4, stride=2, padding=1)
        self.conv_trans2 = nn.ConvTranspose2d(in_channels=3, out_channels=1, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        """Encode and decode ``x``; returns logits.

        The output has the same spatial size as ``x`` when H and W are divisible by 4.
        """
        for conv, pool in ((self.conv1, self.pool1), (self.conv2, self.pool2)):
            x = F.relu(pool(conv(x)))
        x = F.relu(self.conv_trans1(x))
        return self.conv_trans2(x)


In [3]:
# MNIST training split, converted to float tensors by ToTensor().
# NOTE(review): 'PATH' looks like a redacted placeholder — point DATA_ROOT at a real data directory.
DATA_ROOT = 'PATH'

dataset = datasets.MNIST(
    DATA_ROOT,
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# Shuffled mini-batches of 8, loaded by 2 background worker processes.
loader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    num_workers=2,
)


In [4]:
# Train the autoencoder to reconstruct its own input. ToTensor() yields pixel
# values in [0, 1], so the input batch doubles as the BCE target for the logits
# MyModel returns. The MNIST labels are not used.
SEED = 0
torch.manual_seed(SEED)  # reproducible weight init and DataLoader shuffle order

model = MyModel()
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

epochs = 1
model.train()
for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(loader):  # `target` (labels) unused
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, data)  # reconstruction loss against the input
        loss.backward()
        optimizer.step()

        if batch_idx % 1000 == 0:
            print('Epoch {}, Batch idx {}, loss {}'.format(epoch, batch_idx, loss.item()))
# `data` and `output` keep the last batch; the plotting cells below read them.

Epoch 0, Batch idx 0, loss 0.7588210105895996
Epoch 0, Batch idx 1000, loss 0.1379169225692749
Epoch 0, Batch idx 2000, loss 0.11430930346250534
Epoch 0, Batch idx 3000, loss 0.08831126987934113
Epoch 0, Batch idx 4000, loss 0.07700944691896439
Epoch 0, Batch idx 5000, loss 0.07604709267616272
Epoch 0, Batch idx 6000, loss 0.08636042475700378
Epoch 0, Batch idx 7000, loss 0.06384305655956268


In [5]:
def normalize_output(img):
    """Min-max rescale ``img`` into [0, 1] for display.

    Args:
        img: a tensor (or NumPy array) of any shape.

    Returns:
        ``(img - img.min()) / (img.max() - img.min())``. A constant input has
        zero range. It is returned as all zeros instead of being divided by
        zero, which would fill it with NaNs.
    """
    shifted = img - img.min()
    value_range = shifted.max()
    if value_range > 0:
        return shifted / value_range
    return shifted

In [6]:
# Plot some images: pick one random sample from the last training batch and
# pair its input image with the model's min-max normalized reconstruction.
idx = torch.randint(low=0, high=output.shape[0], size=())
img = data[idx, 0]
pred = normalize_output(output[idx, 0])

In [8]:
img_data = (
    img
    .detach()  # drop autograd history so .numpy() is allowed
    .numpy()   # NumPy view sharing memory with `img`, for matplotlib
)

In [None]:
# Show the input next to the model's reconstruction in a single cell. By default
# the inline backend renders and closes a figure at the end of the cell that
# created it. Drawing into `axarr` from a later cell therefore never appears, and
# it raises NameError on a fresh kernel.
fig, axarr = plt.subplots(1, 2)
axarr[0].imshow(img_data)
axarr[0].set_title('input')
axarr[1].imshow(pred.detach().numpy())
axarr[1].set_title('reconstruction (min-max normalized)')
plt.show()

In [None]:
# Visualize feature maps
# Forward hooks created by get_activation() store their layer's output here,
# keyed by the name they were created with.
activation = {}


def get_activation(name):
    """Build a forward hook that saves the layer's detached output as ``activation[name]``.

    The hook returns None, so it does not alter the layer's output.
    """
    def hook(module, args, out):
        activation[name] = out.detach()
    return hook

In [None]:
# Capture conv1's feature maps for the first dataset sample through a forward hook.
# - The hook handle is removed afterwards, so re-running this cell does not stack
#   duplicate hooks on model.conv1.
# - New names (`sample`, `sample_output`) keep the last training batch in
#   `data` / `output` intact for the plotting cells above.
sample, _label = dataset[0]
sample = sample.unsqueeze(0)  # add a batch dimension: (C, H, W) -> (1, C, H, W)

hook_handle = model.conv1.register_forward_hook(get_activation('conv1'))
try:
    with torch.no_grad():  # inference only; no autograd graph is needed
        sample_output = model(sample)
finally:
    hook_handle.remove()

act = activation['conv1'][0]  # drop the batch dimension -> (channels, H, W)
# squeeze=False keeps `axarr` a 2-D array even when there is only one channel.
fig, axarr = plt.subplots(act.size(0), squeeze=False)
for idx in range(act.size(0)):
    axarr[idx, 0].imshow(act[idx])