In [None]:
import torch
import torch.nn as nn

from torchinfo import summary

In [None]:
def build_circle_segmenter():
    """
    Build a small convolutional encoder/decoder that maps a 1-channel image
    to a 1-channel image.

    Two 2x2 max-pooling stages shrink height and width by a factor of 4 and
    two stride-2 transposed convolutions scale them back up, so an input whose
    height and width are multiples of 4 yields an output of the same size.
    The last layer is a ReLU, so the output is never negative.

    Returns:
        nn.Sequential: the freshly initialised (untrained) network.
    """

    def conv3x3(in_channels, out_channels):
        # kernel 3, stride 1, padding 1 leaves height and width unchanged
        return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def downsample():
        # halves height and width
        return nn.MaxPool2d(kernel_size=2, stride=2)

    def upsample(in_channels, out_channels):
        # kernel 2, stride 2, padding 0 doubles height and width
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0)

    layers = [
        # encoder, full resolution: 1 -> 8 channels
        conv3x3(1, 8), nn.ReLU(),
        conv3x3(8, 8), nn.ReLU(),
        downsample(),

        # encoder, 1/2 resolution: 8 -> 16 channels
        conv3x3(8, 16), nn.ReLU(),
        conv3x3(16, 16), nn.ReLU(),
        downsample(),

        # bottleneck, 1/4 resolution: 16 -> 32 channels
        conv3x3(16, 32), nn.ReLU(),
        conv3x3(32, 32), nn.ReLU(),

        # decoder, back to 1/2 resolution: 32 -> 16 channels
        upsample(32, 16), nn.ReLU(),
        conv3x3(16, 16), nn.ReLU(),

        # decoder, back to full resolution: 16 -> 8 -> 1 channel
        upsample(16, 8), nn.ReLU(),
        conv3x3(8, 1), nn.ReLU(),
    ]

    return nn.Sequential(*layers)

In [None]:
# Show the torchinfo summary of a fresh segmenter for a batch of two 1-channel 128x128 images.
summary(build_circle_segmenter(), input_size=(2, 1, 128, 128))

In [None]:
import itertools

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageDraw

In [None]:
def get_empty_image(image_width, image_height):
    """
    Create a blank PIL image in which every pixel has the value 255.

    Args:
        image_width: width of the image in pixels.
        image_height: height of the image in pixels.

    Returns:
        A new PIL image of mode "F" (32-bit floating point pixels).
    """
    # mode "F" is used rather than e.g. "LA" (8-bit grayscale with alpha channel)
    return Image.new(mode="F", size=(image_width, image_height), color=255)

In [None]:
def draw_a_circle(target_image, circle_e1, circle_e2, circle_radius, outline_color=255):
    """
    Draw a one-pixel-wide circle outline onto target_image, in place.

    The outline is inscribed in a square bounding box centred on
    (circle_e1, circle_e2) that extends circle_radius/2 from the centre in
    every direction, i.e. the side of the box equals circle_radius.

    Args:
        target_image: PIL image to draw on (modified in place).
        circle_e1: first coordinate of the circle centre (x in PIL's box order).
        circle_e2: second coordinate of the circle centre (y in PIL's box order).
        circle_radius: side length of the bounding box of the circle.
        outline_color: pixel value used for the outline.

    Returns:
        target_image, the same object that was passed in.
    """
    half_extent = circle_radius / 2
    bounding_box = (
        circle_e1 - half_extent,
        circle_e2 - half_extent,
        circle_e1 + half_extent,
        circle_e2 + half_extent,
    )

    # an arc running from 0 to 360 degrees is a closed outline with no interior fill
    pen = ImageDraw.ImageDraw(target_image)
    pen.arc(bounding_box, start=0, end=360, width=1, fill=outline_color)

    return target_image


In [None]:
def fill_a_circle(target_image, circle_e1, circle_e2, circle_radius, circle_fill_color=255):
    """
    Draw a filled circle onto target_image, in place.

    The circle is inscribed in a square bounding box centred on
    (circle_e1, circle_e2) that extends circle_radius/2 from the centre in
    every direction, i.e. the side of the box equals circle_radius.
    No explicit outline color is passed to PIL; only the fill is specified.

    Args:
        target_image: PIL image to draw on (modified in place).
        circle_e1: first coordinate of the circle centre (x in PIL's box order).
        circle_e2: second coordinate of the circle centre (y in PIL's box order).
        circle_radius: side length of the bounding box of the circle.
        circle_fill_color: pixel value used for the interior of the circle.

    Returns:
        target_image, the same object that was passed in.
    """
    half_extent = circle_radius / 2
    bounding_box = (
        circle_e1 - half_extent,
        circle_e2 - half_extent,
        circle_e1 + half_extent,
        circle_e2 + half_extent,
    )

    brush = ImageDraw.ImageDraw(target_image)
    brush.ellipse(bounding_box, width=1, fill=circle_fill_color)

    return target_image

In [None]:
def image_of_circles(circle_count):
    """
    Generate one (input, target) pair of 128x128 images of random circles.

    Both images receive the same randomly placed, randomly colored filled
    circles. The target image additionally gets an outline of color 0 drawn
    around every circle after all fills are done, so that no later fill can
    paint over an earlier outline.

    Args:
        circle_count: number of circles to place in the images.

    Returns:
        tuple: (circle_radius_list, image_of_filled_circles, image_of_outlined_circles)
        where circle_radius_list holds the circle_radius value of every
        circle in drawing order, image_of_filled_circles is the input
        image and image_of_outlined_circles is the target image.
    """
    image_side = 128
    image_of_filled_circles = get_empty_image(image_width=image_side, image_height=image_side)
    image_of_outlined_circles = get_empty_image(image_width=image_side, image_height=image_side)

    # choose every circle up front; for each circle the random draws happen
    # in the order radius, e1, e2, fill color
    circle_parameters_list = []
    for _ in range(circle_count):
        radius = np.random.randint(low=10, high=40)
        center_e1 = np.random.randint(low=20, high=80)
        center_e2 = np.random.randint(low=20, high=80)
        fill_color = np.random.randint(low=100, high=200)
        circle_parameters_list.append(
            {
                "circle_radius": radius,
                "circle_e1": center_e1,
                "circle_e2": center_e2,
                "circle_fill_color": fill_color,
            }
        )

    circle_radius_list = [
        circle_parameters["circle_radius"] for circle_parameters in circle_parameters_list
    ]

    # the input image first, then the target image, get the same filled circle
    for circle_parameters in circle_parameters_list:
        for canvas in (image_of_filled_circles, image_of_outlined_circles):
            fill_a_circle(target_image=canvas, **circle_parameters)

    # second pass: outline every circle on the target image only
    for circle_parameters in circle_parameters_list:
        outline_parameters = {
            name: value
            for name, value in circle_parameters.items()
            if name != "circle_fill_color"
        }
        draw_a_circle(
            target_image=image_of_outlined_circles,
            outline_color=0,
            **outline_parameters
        )

    return (circle_radius_list, image_of_filled_circles, image_of_outlined_circles)

In [None]:
import matplotlib.pyplot as plt

# generate 4 sets of input/output
circle_radiuses_list = list()
input_images = list()
output_images = list()

for _ in range(4):
    # each preview image holds between 1 and 4 circles
    circle_count = np.random.randint(low=1, high=5)
    circle_radiuses, input_circles_image, output_circles_image = image_of_circles(circle_count=circle_count)
    circle_radiuses_list.append(circle_radiuses)
    input_images.append(input_circles_image)
    output_images.append(output_circles_image)

# show the four input images and then the four target images, each set in a 2x2 grid
for figure_title, images in (("input images", input_images), ("target images", output_images)):
    fig, axs = plt.subplots(nrows=2, ncols=2)
    fig.suptitle(figure_title)
    for i, (r, c) in enumerate(itertools.product(range(2), range(2))):
        # report the radiuses of the image shown in this subplot; the loop
        # variable circle_radiuses only holds those of the last image generated
        print('circle radiuses: {}'.format(circle_radiuses_list[i]))

        axs[r][c].imshow(images[i], origin="lower")


In [None]:
# a class to interact with DataLoaders
class CircleImageDataset:
    """
    In-memory, map-style dataset of randomly generated circle image pairs.

    Every item is a tuple (target_image, input_image) of numpy arrays. Each
    array is the corresponding image from image_of_circles with an extra
    leading axis of size 1, the 'channel' dimension that PyTorch
    convolutional networks expect.
    """

    def __init__(self, circle_image_count):
        """
        Generate circle_image_count image pairs up front and keep them in
        self.circle_image_list.
        """
        self.circle_image_list = []
        for _ in range(circle_image_count):
            # between 1 and 4 circles per image (the upper bound is exclusive)
            circle_count = np.random.randint(low=1, high=5)

            # the list of radiuses returned first is not used by this dataset
            _, input_circles_image, target_circles_image = image_of_circles(
                circle_count=circle_count
            )

            # the target comes first in every stored pair
            self.circle_image_list.append(
                (
                    self._with_channel_axis(target_circles_image),
                    self._with_channel_axis(input_circles_image),
                )
            )

    @staticmethod
    def _with_channel_axis(image):
        # convert the image to a numpy array and insert a leading 'channel' axis
        return np.array(image)[np.newaxis, ...]

    def __getitem__(self, index):
        """Return the (target_image, input_image) pair stored at index."""
        return self.circle_image_list[index]

    def __len__(self):
        """Return the number of image pairs in the dataset."""
        return len(self.circle_image_list)

In [None]:
def test_circle_image_dataset():
    circle_image_dataset = CircleImageDataset(100)
    print(f"len(circle_image_dataset): {len(circle_image_dataset)}")
    target_circle_image, input_circle_image = circle_image_dataset[99]
    print(f"target image.shape : {target_circle_image.shape}")
    print(f"input image.shape : {input_circle_image.shape}")

# run the dataset smoke test when this cell executes
test_circle_image_dataset()

In [None]:
from torch.utils.data import DataLoader

In [None]:
def test_circle_image_dataloader():
    circle_image_dataloader = DataLoader(CircleImageDataset(circle_image_count=100), batch_size=10)
    for batch in circle_image_dataloader:
        print(f"len(batch): {len(batch)}")
        print(f"len(batch[0]): {len(batch[0])}")
        print(f"batch[0].shape: {batch[0].shape}")
        print(f"len(batch[1]): {len(batch[1])}")
        print(f"batch[1].shape: {batch[1].shape}")

        target_circle_images, input_circle_images = batch

        # note correct_radii.shape does not match predicted_radii.shape
        print(f"target_circle_images.shape: {target_circle_images.shape}")
        print(f"target_circle_images.dtype: {target_circle_images.dtype}")
        print(f"input_circle_images.shape:  {input_circle_images.shape}")

        test_circle_segmenter = build_circle_segmenter()
        predicted_circle_images = test_circle_segmenter(input_circle_images)
        print(f"predicted_circle_images.shape: {predicted_circle_images.shape}")
        print(f"predicted_circle_images.dtype: {predicted_circle_images.dtype}")
        
        break

# run the dataloader smoke test when this cell executes
test_circle_image_dataloader()

In [None]:
# notes from earlier experiments with the size of the training set:
#   100,000 images, no shuffle: works, 20 epochs is ok but sometimes training does not progress
#   100,000 images with shuffling: training is more stable
#   10,000 images, no shuffle: works, 50 epochs is ok

# the training images are generated first, then the held-out test images
training_circle_images = CircleImageDataset(circle_image_count=10000)
train_circle_image_loader = DataLoader(training_circle_images, batch_size=100, shuffle=True)

held_out_circle_images = CircleImageDataset(circle_image_count=1000)
test_circle_image_loader = DataLoader(held_out_circle_images, batch_size=100)

#validate_circle_image_loader = DataLoader(CircleImageDataset(circle_image_count=1000), batch_size=100)

In [None]:
# number of image pairs in the training dataset (shown as this cell's output)
len(train_circle_image_loader.dataset)

In [None]:
def train(
    circle_segmenter_model,
    optimizer,
    loss_function,
    train_dataloader,
    test_dataloader,
    epoch_count
):
    """
    Train circle_segmenter_model and print the mean per-sample training and
    test loss after every epoch.

    The model is moved to the GPU when CUDA is available, otherwise it stays
    on the CPU. It is left on that device, in eval mode, when training ends.

    Args:
        circle_segmenter_model: the torch module to train (modified in place).
        optimizer: optimizer constructed over circle_segmenter_model's parameters.
        loss_function: callable(prediction, target) returning a scalar loss
            tensor. The reported losses assume it returns the mean over the
            batch, as torch.nn.MSELoss does by default.
        train_dataloader: yields (target_images, input_images) batches used
            for optimization; must expose a sized .dataset attribute.
        test_dataloader: yields (target_images, input_images) batches used
            for evaluation only; must expose a sized .dataset attribute.
        epoch_count: number of passes over train_dataloader.

    Returns:
        None. Progress is reported with print().
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    circle_segmenter_model.to(device)

    # normalise by the loaders that were passed in, not by a notebook global;
    # max(..., 1) avoids a division by zero for an empty dataset
    train_sample_count = max(len(train_dataloader.dataset), 1)
    test_sample_count = max(len(test_dataloader.dataset), 1)

    for epoch_i in range(epoch_count):
        training_loss = 0.0
        circle_segmenter_model.train()
        for correct_segmented_circle_images, circle_images in train_dataloader:
            optimizer.zero_grad()

            # torch calls circle_images 'inputs'
            circle_images = circle_images.to(device)
            correct_segmented_circle_images = correct_segmented_circle_images.to(device)

            predicted_segmented_circle_images = circle_segmenter_model(circle_images)

            loss = loss_function(predicted_segmented_circle_images, correct_segmented_circle_images)
            loss.backward()
            optimizer.step()

            # the loss is a batch mean; weight it by the batch size so that
            # dividing by the sample count below gives a per-sample mean
            training_loss += loss.item() * circle_images.size(0)

        training_loss /= train_sample_count

        test_loss = 0.0
        circle_segmenter_model.eval()
        # evaluation needs no gradients; skipping them saves memory and time
        with torch.no_grad():
            for correct_segmented_circle_images, circle_images in test_dataloader:
                circle_images = circle_images.to(device)
                correct_segmented_circle_images = correct_segmented_circle_images.to(device)

                predicted_segmented_circle_images = circle_segmenter_model(circle_images)

                loss = loss_function(predicted_segmented_circle_images, correct_segmented_circle_images)
                test_loss += loss.item() * circle_images.size(0)

        test_loss /= test_sample_count

        print(
            'Epoch: {}, Training Loss: {:.2f}, Test Loss: {:.2f}'.format(
                epoch_i, training_loss, test_loss
            )
        )

In [None]:
import torch.optim

circle_segmenter = build_circle_segmenter()

# Adam with its default settings, minimising the mean squared error
# between the predicted and the target images
adam_optimizer = torch.optim.Adam(circle_segmenter.parameters())
mean_squared_error = torch.nn.MSELoss()

train(
    circle_segmenter_model=circle_segmenter,
    optimizer=adam_optimizer,
    loss_function=mean_squared_error,
    train_dataloader=train_circle_image_loader,
    test_dataloader=test_circle_image_loader,
    epoch_count=100
)

In [None]:
# try out the circle segmenter
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# put the model on the device the inputs are sent to and switch it to eval
# mode regardless of which device was chosen
circle_segmenter.to(device)
circle_segmenter.eval()

a_circle_image_dataloader = DataLoader(CircleImageDataset(10), batch_size=1)
for a_circle_target_image, a_circle_input_image in a_circle_image_dataloader:
    print(f"input shape : {a_circle_input_image.shape}")

    # inference only: no gradients are needed
    with torch.no_grad():
        a_segmented_circle_tensor = circle_segmenter(a_circle_input_image.to(device))
    a_segmented_circle_image = a_segmented_circle_tensor.cpu().detach().numpy()

    print(f"output shape: {a_segmented_circle_image.shape}")

    # left: the image given to the network, right: what the network produced
    fig, axs = plt.subplots(nrows=1, ncols=2)
    axs[0].imshow(a_circle_input_image[0, 0, :, :], origin="lower")
    axs[0].set_title("input")
    axs[1].imshow(a_segmented_circle_image[0, 0, :, :], origin="lower")
    axs[1].set_title("segmenter output")
