In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader

# Paths

In [5]:
# Root directory holding the SBEM crop images. The default is relative to the notebook's
# working directory; set the SBEM_DATA_ROOT environment variable to point elsewhere
# without editing the notebook.
data_root = os.environ.get('SBEM_DATA_ROOT', '../../data/')

# Dataset

In [6]:
# Let's look at some data: load a single raw crop and display it.
sample_path = os.path.join(data_root, '2016-05-26_st015_000303_000_000_000_n_00_crop_000.tif')
img = plt.imread(sample_path)
plt.imshow(img, cmap='gray');

FileNotFoundError: [Errno 2] No such file or directory: '../../data/2016-05-26_st015_000303_000_000_000_n_00_crop_000.tif'

In [None]:
# Quick-and-dirty normalization stats, estimated from the single image `img`.
norm_params = dict(mean=np.mean(img), std=np.std(img))

## Define Dataset subclass

In [None]:
class SBEMCrop2dDataset(Dataset):
    """Map-style dataset of 2D SBEM crops read from image files in one directory.

    Each sample is read with ``plt.imread``, z-score normalized with the supplied
    statistics and returned as a float32 tensor of shape ``(1, H, W)``.

    Args:
        data_root: Directory containing the image files.
        norm_params: Mapping with ``'mean'`` and ``'std'`` keys used by ``_normalize``.
        extensions: Optional filename suffix, or iterable of suffixes, to keep
            (matched case-insensitively, e.g. ``('.tif', '.tiff')``). ``None``
            (default) keeps every regular, non-hidden file.
    """

    def __init__(self, data_root, norm_params, extensions=None):
        self.data_fnames = self._list_files(data_root, extensions)
        self.data_root = data_root
        self.norm_params = norm_params

    @staticmethod
    def _list_files(data_root, extensions=None):
        """Return the sorted names of candidate image files directly inside ``data_root``.

        ``os.listdir`` returns entries in arbitrary order, so the result is sorted to make
        the index -> file mapping reproducible. Sub-directories and hidden files
        (e.g. ``.DS_Store``) are skipped because ``__getitem__`` cannot read them as images.
        """
        if isinstance(extensions, str):
            # A bare string would otherwise be iterated character by character.
            extensions = (extensions,)
        if extensions is not None:
            extensions = tuple(ext.lower() for ext in extensions)

        fnames = []
        for fname in sorted(os.listdir(data_root)):
            if fname.startswith('.'):
                continue
            if not os.path.isfile(os.path.join(data_root, fname)):
                continue
            if extensions is not None and not fname.lower().endswith(extensions):
                continue
            fnames.append(fname)
        return fnames

    def __len__(self):
        """Number of files found in ``data_root``."""
        return len(self.data_fnames)

    def __getitem__(self, idx):
        """Load, normalize and convert the ``idx``-th file to a ``(1, H, W)`` float tensor.

        Assumes the file decodes to a single-channel 2D array (TODO confirm for all crops);
        images with more than one channel fail in ``_reshape_to_torch``.
        """
        img = np.asarray(plt.imread(os.path.join(self.data_root, self.data_fnames[idx])))
        img = self._normalize(img)
        sample = SBEMCrop2dDataset._reshape_to_torch(img)
        return sample

    def _normalize(self, img):
        """Return ``(img - mean) / std`` using ``self.norm_params``."""
        img = (np.asarray(img)-self.norm_params['mean'])/self.norm_params['std']
        return img

    @staticmethod
    def _reshape_to_torch(img):
        """Add a leading channel axis to a 2D array and return it as a float32 tensor."""
        sample = torch.from_numpy(np.reshape(img, (1, img.shape[0], img.shape[1]))).float()
        return sample

## Instantiate dataset subclass

In [None]:
# One sample per file listed in `data_root`, normalized with the stats estimated above.
sbem_dataset = SBEMCrop2dDataset(data_root=data_root, norm_params=norm_params)

## Test dataset subclass

In [None]:
# Show the first (up to) 3 images in the dataset with their post-normalization stats.
# The loop uses `sample_img` rather than `img`, so the raw image loaded earlier is not
# overwritten and re-running the normalization-stats cell still sees raw data.
n_show = min(3, len(sbem_dataset))
fig, axs = plt.subplots(1, 3, figsize=(16,12))

for ax, idx in zip(axs, range(n_show)):
    sample = sbem_dataset[idx]
    sample_img = sample.numpy().squeeze()
    ax.imshow(sample_img, cmap='gray')
    ax.set_title('mean: {:0.3f}, std: {:0.3f}'.format(np.mean(sample_img), np.std(sample_img)))

In [None]:
# Inspect the dtype of the array currently bound to `img`.
np.asarray(img).dtype

# Model

## Define Model

In [None]:
class ConvAE2D(torch.nn.Module):
    """Fully convolutional 2D autoencoder without spatial down-sampling.

    Encoder: 1 -> 16 -> 8 -> 4 channels; decoder mirrors it back to 1 channel.
    Every layer uses a 5x5 kernel, stride 1 and padding 2, so the output has the
    same (N, 1, H, W) shape as the input.

    The output layer is deliberately linear: ``SBEMCrop2dDataset`` z-scores its images
    (mean 0), so reconstruction targets include negative values that a Sigmoid output
    in (0, 1) could never reproduce.
    """

    def __init__(self):
        super(ConvAE2D, self).__init__()
        self.encoding_layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            torch.nn.ReLU())
        self.encoding_layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(16, 8, kernel_size=5, stride=1, padding=2),
            torch.nn.ReLU())
        self.encoding_layer3 = torch.nn.Sequential(
            torch.nn.Conv2d(8, 4, kernel_size=5, stride=1, padding=2),
            torch.nn.ReLU())
        self.decoding_layer1 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(4, 8, kernel_size=5, stride=1, padding=2),
            torch.nn.ReLU())
        self.decoding_layer2 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(8, 16, kernel_size=5, stride=1, padding=2),
            torch.nn.ReLU())
        self.decoding_layer3 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(16, 1, kernel_size=5, stride=1, padding=2),
            # Identity (not Sigmoid) so outputs can match signed, z-scored targets.
            # Kept inside the Sequential so parameter names (decoding_layer3.0.*) are unchanged.
            torch.nn.Identity())

    def forward(self, x):
        """Encode then decode ``x``; returns a tensor shaped like ``x``."""
        x = self.encoding_layer1(x)
        x = self.encoding_layer2(x)
        x = self.encoding_layer3(x)
        x = self.decoding_layer1(x)
        x = self.decoding_layer2(x)
        x = self.decoding_layer3(x)
        return x

In [None]:
class ConvAE2D_2(torch.nn.Module):
    """2D convolutional autoencoder with two 2x down-/up-sampling stages.

    Encoder: conv(1->16) -> maxpool/2 -> conv(16->8) -> maxpool/2 -> conv(8->4).
    Decoder: conv^T(4->8) -> upsample x2 -> conv^T(8->16) -> upsample x2 -> conv^T(16->1).

    All (transposed) convolutions are 3x3, stride 1, padding 1, so spatial size changes
    only at the pooling/upsampling steps. The output therefore has the same (N, 1, H, W)
    shape as the input when H and W are divisible by 4; MaxPool2d floors odd sizes, and
    the x2 upsampling cannot undo that.

    The output layer is deliberately linear: ``SBEMCrop2dDataset`` z-scores its images
    (mean 0), so reconstruction targets include negative values that a Sigmoid output
    in (0, 1) could never reproduce.
    """

    def __init__(self):
        super(ConvAE2D_2, self).__init__()
        self.encoding_layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU())
        self.encoding_pool1 = torch.nn.MaxPool2d(2)
        self.encoding_layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(16, 8, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU())
        self.encoding_pool2 = torch.nn.MaxPool2d(2)
        self.encoding_layer3 = torch.nn.Sequential(
            torch.nn.Conv2d(8, 4, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU())
        self.decoding_layer1 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(4, 8, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU())
        self.decoding_up1 = torch.nn.Upsample(scale_factor=2, mode='nearest')
        self.decoding_layer2 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(8, 16, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU())
        self.decoding_up2 = torch.nn.Upsample(scale_factor=2, mode='nearest')
        self.decoding_layer3 = torch.nn.Sequential(
            # padding=1 keeps H and W unchanged; padding=0 would grow each by 2 and make
            # the output shape differ from the (input-shaped) reconstruction target.
            torch.nn.ConvTranspose2d(16, 1, kernel_size=3, stride=1, padding=1),
            # Identity (not Sigmoid) so outputs can match signed, z-scored targets.
            # Kept inside the Sequential so parameter names (decoding_layer3.0.*) are unchanged.
            torch.nn.Identity())

    def forward(self, x):
        """Encode then decode ``x``; returns a tensor shaped like ``x`` (H, W divisible by 4)."""
        x = self.encoding_layer1(x)
        x = self.encoding_pool1(x)
        x = self.encoding_layer2(x)
        x = self.encoding_pool2(x)
        x = self.encoding_layer3(x)
        x = self.decoding_layer1(x)
        x = self.decoding_up1(x)
        x = self.decoding_layer2(x)
        x = self.decoding_up2(x)
        x = self.decoding_layer3(x)
        return x

## Instantiate Model, Loss and Optimizer

In [None]:
# Seed PyTorch's global RNG so weight initialisation is reproducible. The shuffling
# DataLoader in the training cell draws from the same default generator, so the batch
# order is reproducible too when the notebook is run top to bottom.
seed = 0
torch.manual_seed(seed)

net = ConvAE2D_2()
criterion = torch.nn.MSELoss()  # reconstruction error; the training loop uses the input as target
learning_rate = 0.01
momentum = 0.9
optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum)

# Train Model

In [None]:
# Train the autoencoder to reconstruct its own input. The dataset yields image tensors
# only (no labels), so each batch is both the network input and the regression target.
trainloader = DataLoader(sbem_dataset, batch_size=4, shuffle=True, num_workers=0)
n_epoch = 30
log_every = 10  # print the running mean loss every `log_every` batches

net.train()
for epoch in range(n_epoch):  # loop over the dataset multiple times
    running_loss = 0.0
    epoch_loss = 0.0
    n_batches = 0
    for i, inputs in enumerate(trainloader):
        labels = inputs  # autoencoder: target is the input itself

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss += batch_loss
        n_batches += 1
        if i % log_every == log_every - 1:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / log_every))
            running_loss = 0.0

    # Per-epoch summary: with fewer than `log_every` batches per epoch the periodic
    # print above never fires, so this is then the only progress output.
    if n_batches:
        print('[%d] epoch mean loss: %.3f' % (epoch + 1, epoch_loss / n_batches))

In [None]:
# Compare one input from the last training batch with its reconstruction.
# Both panels share one grey-value range: with imshow's per-image autoscaling, a
# reconstruction with a different value range could look deceptively similar.
fig, axs = plt.subplots(1, 2, figsize=(16,12))

img_input = inputs[0].detach().numpy().squeeze()
img_output = outputs[0].detach().numpy().squeeze()
vmin = min(img_input.min(), img_output.min())
vmax = max(img_input.max(), img_output.max())

axs[0].imshow(img_input, cmap='gray', vmin=vmin, vmax=vmax)
axs[0].set_title('input')

axs[1].imshow(img_output, cmap='gray', vmin=vmin, vmax=vmax)
axs[1].set_title('reconstruction');