## Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cv2
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import Dataset, DataLoader
from torch.optim.optimizer import Optimizer
from tqdm import tqdm
import datetime

# Use the GPU when one is available; the model and every batch are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Canvas (height, width) in pixels -- a tuple, so annotated as such.
# NOTE(review): not referenced below; EllipsesDataset hard-codes the same
# (256, 256) default, and Autoencoder only maps 256x256 -> 256x256.
CANVAS_SIZE: tuple[int, int] = (256, 256)

## Visualize

In [None]:
def show(img: np.ndarray) -> None:
    """Display ``img`` as a grayscale picture with the axes hidden.

    Args:
        img: Image array to render with ``plt.imshow``.
    """
    gray_cmap = 'gray'
    plt.imshow(img, cmap=gray_cmap)
    plt.axis('off')  # no ticks, labels or frame
    plt.show()

In [None]:
def visualize_ellipses(img1, img2, img3):
    """Plot the corrupted input, the target and the model output side by side.

    Args:
        img1: Input image with holes punched into it.
        img2: Clean target image.
        img3: Autoencoder reconstruction of ``img1``.
    """
    _, axs = plt.subplots(1, 3, figsize=(10, 5))
    images = (img1, img2, img3)
    titles = ("With holes", "Target", "Autoencoder output")
    for ax, image, title in zip(axs, images, titles):
        ax.imshow(image, cmap='gray')
        ax.axis('off')
        ax.set_title(title)
    plt.show()

## Dataset

In [None]:
class EllipsesDataset(Dataset):
    """Synthetic dataset of filled ellipses with random circular holes.

    Each item is a freshly generated ``(input, target)`` pair. ``target`` is a
    uint8 canvas with one filled ellipse drawn in 255 on a 0 background.
    ``input`` is a copy of it with circles of value 0 punched into the
    ellipse. ``idx`` is ignored, so every access yields a new random sample.
    Two instances therefore draw from the same distribution rather than
    holding fixed, disjoint sets.

    Args:
        transform: Optional callable applied to both images. Without it the
            items are returned as ``np.ndarray`` (despite the return
            annotation of ``__getitem__``).
        length: Value reported by ``len()``, i.e. samples per epoch.
            Defaults to 1000.
    """

    def __init__(self, transform = None, length: int = 1000):
        self.transform = transform
        self.length = length

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor]:
        # idx is intentionally unused: samples are generated on the fly.
        target_img = self.__generate_random_ellipse()
        input_img = self.__add_random_holes(target_img)

        if self.transform:
            target_img = self.transform(target_img)
            input_img = self.transform(input_img)

        return input_img, target_img

    def __generate_random_ellipse(self, img_size: tuple[int, int] = (256, 256)) -> np.ndarray:
        """Return an ``img_size`` (rows, cols) uint8 image with one filled ellipse."""
        height, width = img_size
        img = np.zeros(img_size, dtype=np.uint8)
        # OpenCV takes points and axis lengths as (x, y) = (column, row), so x
        # comes from the width and y from the height. This matters only for
        # non-square canvases.
        center: tuple[int, int] = (
            np.random.randint(int(width * 0.2), int(width * 0.8)),
            np.random.randint(int(height * 0.2), int(height * 0.8)),
        )
        axes: tuple[int, int] = (
            np.random.randint(int(width * 0.2), int(width * 0.6)),
            np.random.randint(int(height * 0.2), int(height * 0.6)),
        )
        angle: int = np.random.randint(0, 180)
        # Full 0..360 sweep, colour 255, thickness -1 => filled ellipse.
        cv2.ellipse(img, center, axes, angle, 0, 360, 255, -1)
        return img

    def __add_random_holes(self, img: np.ndarray, hole_count: int = 30, max_hole_radius: int = 30) -> np.ndarray:
        """Return a copy of ``img`` with random circles of value 0 punched in.

        Between 5 and ``hole_count - 1`` holes are attempted. Each is centred
        on a pixel that was 255 in the original ``img``, so holes may overlap,
        and has a radius in ``[5, max_hole_radius)``. Punching stops as soon
        as fewer than 50 pixels of value 255 remain.
        """
        img = img.copy()  # never modify the caller's (target) image
        ellipse_points: np.ndarray = np.argwhere(img == 255)

        n_holes: int = np.random.randint(5, hole_count)
        for _ in range(n_holes):
            # Count the foreground that is *still* left. A count taken once
            # before the loop never changes, which made this guard dead code.
            if np.count_nonzero(img == 255) < 50:
                break  # otherwise looks like a swiss cheese

            center_y, center_x = ellipse_points[np.random.choice(len(ellipse_points))]
            radius = np.random.randint(5, max_hole_radius)
            cv2.circle(img, (int(center_x), int(center_y)), int(radius), 0, -1)

        return img

## Model and training loop (PyTorch)

In [None]:
class Autoencoder(nn.Module):
    """Convolutional autoencoder mapping a 1-channel image to a 1-channel image.

    The spatial sizes in the comments are for a 256x256 input, which is
    compressed to an 8x3x3 code and decoded back to 1x256x256. Other input
    sizes will not round-trip to their own size.

    The layer order inside each ``nn.Sequential`` is part of the
    ``state_dict`` key names (``encoder.<i>.weight`` ...). Keep it stable so
    saved checkpoints remain loadable.
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 8, 3, padding=1),              # b, 8, 256, 256
            nn.ReLU(inplace=True),
            nn.Conv2d(8, 16, 3, stride=3, padding=1),   # b, 16, 86, 86
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2),                  # b, 16, 43, 43
            nn.BatchNorm2d(16),
            nn.Conv2d(16, 8, 3, stride=2, padding=1),   # b, 8, 22, 22
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2),                  # b, 8, 11, 11
            nn.Conv2d(8, 8, 3, stride=2, padding=1),    # b, 8, 6, 6
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2),                  # b, 8, 3, 3
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 16, 5, stride=2, output_padding=1, padding=1),  # b, 16, 8, 8
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(16, 8, 3, stride=2, output_padding=1, padding=1),  # b, 8, 16, 16
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(8, 8, 3, stride=2, output_padding=1, padding=1),   # b, 8, 32, 32
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(8, 4, 3, stride=2, output_padding=1, padding=1),   # b, 4, 64, 64
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(4, 4, 3, stride=2, output_padding=1, padding=1),   # b, 4, 128, 128
            nn.BatchNorm2d(4),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(4, 4, 3, stride=2, output_padding=1, padding=1),   # b, 4, 256, 256
            nn.ReLU(inplace=True),
            nn.Conv2d(4, 4, 3, stride=1, padding=1),    # b, 4, 256, 256
            nn.ReLU(inplace=True),
            nn.Conv2d(4, 1, 3, stride=1, padding=1),    # b, 1, 256, 256
            nn.Sigmoid(),                               # per-pixel value in [0, 1]
        )

    def forward(self, x):
        """Encode ``x`` to the bottleneck code and decode it back."""
        return self.decoder(self.encoder(x))

In [None]:
def _run_epoch(model, criterion, batches, optimizer=None):
    """Run one pass over ``batches``; back-propagate and step only if ``optimizer`` is given.

    Returns ``(mean_loss, pixel_accuracy, last_batch)``.
    - ``mean_loss`` is weighted by batch size, so a short final batch is not
      over-weighted. This assumes ``criterion`` uses mean reduction.
    - ``pixel_accuracy`` is the fraction of pixels where ``outputs > 0.5``
      equals the target.
    - ``last_batch`` is ``(inputs, targets, detached outputs)`` of the final
      batch.
    For an empty ``batches`` the metrics are NaN and ``last_batch`` is None.
    """
    total_loss = 0.0
    total_samples = 0
    correct_pixels = 0
    total_pixels = 0
    last_batch = None

    for inputs, targets in batches:
        inputs = inputs.to(device)
        targets = targets.to(device)

        if optimizer is not None:
            optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        if optimizer is not None:
            loss.backward()
            optimizer.step()

        batch_size = inputs.size(0)
        total_loss += loss.item() * batch_size
        total_samples += batch_size
        preds = (outputs > 0.5).float()
        correct_pixels += int((preds == targets).sum().item())
        total_pixels += targets.numel()
        # Detach so the autograd graph of the last batch is not kept alive.
        last_batch = (inputs, targets, outputs.detach())

    if total_samples == 0:
        return float('nan'), float('nan'), None
    return total_loss / total_samples, correct_pixels / total_pixels, last_batch


def train_model(model: Autoencoder, criterion: torch.nn.Module, optimizer: Optimizer, dataloader: dict, EPOCHS=5, visualize: bool = True) -> tuple[Autoencoder, dict[str, list]]:
    """Train ``model`` and evaluate it on the validation split after every epoch.

    Args:
        model: Network to train. Batches are moved to the module-level
            ``device``, so the model is expected to live there too.
        criterion: Loss computed as ``criterion(outputs, targets)``.
        optimizer: Optimizer over ``model``'s parameters.
        dataloader: Mapping with ``'train'`` and ``'val'`` iterables that
            yield ``(inputs, targets)`` batches.
        EPOCHS: Number of passes over the training loader.
        visualize: If true (the default, matching previous behaviour), plot
            the first sample of the last training batch after every epoch.

    Returns:
        The trained model and a history dict with per-epoch float lists under
        ``'loss'``, ``'accuracy'``, ``'val_loss'`` and ``'val_accuracy'``.

    Accuracy thresholds outputs at 0.5 and compares them with the targets
    exactly. It is only meaningful when targets are 0/1 masks (the case with
    the ``ToTensor`` transform used in this notebook). NOTE(review): confirm
    for other transforms.
    """
    history: dict[str, list] = {
        'loss': [],
        'accuracy': [],
        'val_loss': [],
        'val_accuracy': [],
    }

    for epoch in range(EPOCHS):
        # Training
        model.train()
        epoch_loss, epoch_accuracy, last_batch = _run_epoch(
            model, criterion, tqdm(dataloader['train']), optimizer
        )
        history['loss'].append(epoch_loss)
        history['accuracy'].append(epoch_accuracy)

        # Validation
        model.eval()
        with torch.no_grad():
            val_epoch_loss, val_epoch_accuracy, _ = _run_epoch(model, criterion, dataloader['val'])
        history['val_loss'].append(val_epoch_loss)
        history['val_accuracy'].append(val_epoch_accuracy)

        if visualize and last_batch is not None:
            inputs, targets, outputs = last_batch
            visualize_ellipses(inputs[0].cpu().detach().numpy().squeeze(), targets[0].cpu().detach().numpy().squeeze(), outputs[0].cpu().detach().numpy().squeeze())

        print(f'Epoch {epoch+1}/{EPOCHS} Loss: {epoch_loss:.4f} Accuracy: {epoch_accuracy:.4f} Val Loss: {val_epoch_loss:.4f} Val Accuracy: {val_epoch_accuracy:.4f}')

    return model, history

## Train

In [None]:
# The same transform is applied to inputs and targets. NOTE(review):
# ToTensor is expected to turn the uint8 images into float tensors in
# [0, 1] -- confirm against the torchvision docs.
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

# Both splits generate fresh random samples on every access.
datasets: dict[str, Dataset] = {
    split: EllipsesDataset(transform=transform) for split in ('train', 'val')
}

# Only the training split is shuffled.
dataloaders: dict[str, DataLoader] = {
    split: DataLoader(ds, batch_size=32, shuffle=(split == 'train'))
    for split, ds in datasets.items()
}

model = Autoencoder().to(device)
# The decoder ends in a Sigmoid, so binary cross-entropy fits.
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
model, history = train_model(model, criterion, optimizer, dataloaders, EPOCHS=1000)

# save model
# torch.save does not create missing directories. Without this, a missing
# results/ folder makes the save raise *after* the whole training run,
# losing the trained weights.
import os
os.makedirs('results', exist_ok=True)
torch.save(model.state_dict(), f'results/autoencoder_{datetime.datetime.now().strftime("%Y-%m-%d")}_{history["val_accuracy"][-1]}.pt')

## Load & Run

In [None]:
# file_path: str = "results/" + "autoencoder_2024-11-11_[0.5848712158203125].pt"
# autoencoder = Autoencoder()
# autoencoder.load_state_dict(torch.load(file_path))
# autoencoder.to(device)

# # Test
# autoencoder.eval()
# with torch.no_grad():
#     for inputs, targets in dataloaders['val']:
#         inputs = inputs.to(device)
#         targets = targets.to(device)
#         outputs = autoencoder(inputs)
#         visualize_ellipses(inputs[0].cpu().detach().numpy().squeeze(), targets[0].cpu().detach().numpy().squeeze(), outputs[0].cpu().detach().numpy().squeeze())
#         break