In [None]:
import torch
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt

%matplotlib inline
# Notebook-wide matplotlib defaults, applied in one call:
# default figure size, nearest-neighbour image interpolation, grayscale colormap.
plt.rcParams.update({
    'figure.figsize': (10.0, 8.0),
    'image.interpolation': 'nearest',
    'image.cmap': 'gray',
})

%load_ext autoreload
%autoreload 2

### Data Visualization

In [None]:
# Load the input image and its label array from disk.
image = np.load('data/utek.npy')
label = np.load('data/label.npy')

# Show both arrays side by side with the axis decorations hidden.
fig, (ax1, ax2) = plt.subplots(1, 2)
for panel, (caption, pixels) in zip((ax1, ax2), (('Image', image), ('Label', label))):
    panel.imshow(pixels)
    panel.axis('off')
    panel.title.set_text(caption)
plt.show()

### Model Architecture

In [None]:
class Conv2D(torch.nn.Module):
    """Convolutional detector module (exercise skeleton).

    The convolution layer and the three post-convolution steps in ``forward``
    are still ``...`` placeholders, so the class cannot be used for training
    until they are filled in.
    """
    def __init__(self, channels, kernel_size, stride=1, padding=None):
        """Create the module and its convolution layer.

        :param channels: presumably the number of image channels -- TODO confirm
        :param kernel_size: presumably the size of the convolution kernel -- TODO confirm
        :param stride: presumably the convolution stride (defaults to 1) -- TODO confirm
        :param padding: presumably the convolution padding; the handling of
            ``None`` is not defined yet -- TODO confirm
        """
        super(Conv2D, self).__init__()

        # TODO: placeholder. The rest of the notebook calls ``self.conv(x)`` and
        # reads ``model.conv.parameters()``, so this has to become a torch.nn
        # layer that owns trainable parameters.
        self.conv = ...

    def forward(self, x):
        """Apply the convolution followed by three post-processing steps.

        :param x: input tensor handed to ``self.conv``
        :return: result of the last post-processing step; the training cell
            reshapes it to (200, 200)
        """
        x = self.conv(x)
        # TODO: the three steps below are placeholders; as written each one
        # overwrites ``x`` with the Ellipsis object.
        x = ...
        x = ...
        x = ...
        return x

### Loss Function

In [None]:
def l2_loss(y_pred, y_true):
    """ L2 loss: sum of squared element-wise differences
    :param y_pred: tensor of shape (height, width)
    :param y_true: tensor of shape (height, width)
    :return: scalar (0-dimensional tensor), differentiable w.r.t. y_pred
    """
    # Squared L2 norm of the residual; reducing with sum() yields a 0-d tensor
    # that supports .backward().
    residual = y_pred - y_true
    return (residual ** 2).sum()

### Training Convolutional Detector

In [None]:
import matplotlib
matplotlib.use('Qt5Agg')

# parameters
learning_rate = ...
n_iterations = ...
kernel_size = ...
padding = ...

# creating model
model = Conv2D(channels=..., kernel_size=kernel_size, stride=1, padding=padding)
optimizer = ...

# converting input to tensors
x = torch.tensor(image.transpose(2, 0, 1), dtype=torch.float)
y = torch.tensor(label, dtype=torch.float)

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)

# training loop
for i in range(n_iterations):

    # Compute positions of the joints
    output = model.forward(x)

    # Compute loss
    loss = l2_loss(output, y)

    # Compute gradient
    loss.backward()

    # Make an optimization step and reset the gradient
    optimizer.step()
    optimizer.zero_grad()

    print(f'Iteration: {i}, loss = {loss.detach().numpy()}')
    fig.suptitle(f'Iteration: {i}, loss: {loss.detach().numpy()}', fontsize=16)
    ax1.imshow(image)
    ax1.axis('off')
    ax1.title.set_text('Image')
    ax2.imshow(label)
    ax2.axis('off')
    ax2.title.set_text('Label')
    ax3.imshow(output.cpu().detach().numpy().reshape((200, 200)))
    ax3.axis('off')
    ax3.title.set_text('Prediction')
    kernel = torch.sigmoid(list(model.conv.parameters())[0]).sum(axis=0).squeeze(0).permute(1, 2, 0).cpu().detach().numpy()/3
    ax4.imshow(kernel)
    ax4.axis('off')
    ax4.title.set_text('Kernel')
    plt.show()
    plt.pause(0.1)

# Change matplotlib backend back to inline
%matplotlib inline

In [None]:
# Static summary figure of the state after the last training iteration.
# Relies on `i`, `loss`, `output` and `model` left behind by the training cell.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15,15))
fig.suptitle(f'Iteration: {i}, loss: {loss.detach().numpy()}', fontsize=16)

prediction = output.cpu().detach().numpy().reshape((200, 200))
# Squash the first parameter of the convolution into an image for display.
kernel = torch.sigmoid(list(model.conv.parameters())[0]).sum(axis=0).squeeze(0).permute(1, 2, 0).cpu().detach().numpy()/3

panels = (
    (ax1, image, 'Image'),
    (ax2, label, 'Label'),
    (ax3, prediction, 'Prediction'),
    (ax4, kernel, 'Kernel'),
)
for panel, pixels, caption in panels:
    panel.imshow(pixels)
    panel.axis('off')
    panel.title.set_text(caption)

# Render the figure; the original ended with plt.plot(), which only adds an
# empty line plot to the current axes and echoes `[]` as the cell output.
plt.show()