## Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cv2
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import Dataset, DataLoader
from torch.optim.optimizer import Optimizer
from tqdm import tqdm

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# (height, width) in pixels of every generated image; used directly as the
# numpy array shape. (256, 256) round-trips exactly through the Autoencoder's
# layer stack, many other sizes do not.
CANVAS_SIZE: tuple[int, int] = (256, 256)

In [None]:
def show(img: np.ndarray) -> None:
    """Display *img* as a grayscale picture with axis lines, ticks and labels hidden."""
    plt.imshow(img, cmap='gray')
    current_axes = plt.gca()
    current_axes.set_axis_off()
    plt.show()

## Ellipses

In [None]:
class Ellipses:
    """Generator of synthetic binary ellipse images, optionally with circular holes.

    Images are ``uint8`` arrays of shape ``CANVAS_SIZE`` (rows, cols) with
    foreground pixels set to 255 and background pixels set to 0.
    """

    def __init__(self) -> None:
        # (rows, cols) of every generated image.
        self.img_size: tuple[int, int] = CANVAS_SIZE

    def generate_random_ellipse(self) -> np.ndarray:
        """Return a new image containing one filled ellipse at a random pose.

        The centre lies within the middle 60% of each image dimension and each
        semi-axis is 20-60% of the corresponding dimension, so the ellipse may
        be clipped by the image border.
        """
        height, width = self.img_size
        img = np.zeros((height, width), dtype=np.uint8)
        # OpenCV expects points and axis lengths as (x, y) == (column, row),
        # so x is drawn from the width and y from the height.
        center: tuple[int, int] = (
            np.random.randint(int(width * 0.2), int(width * 0.8)),
            np.random.randint(int(height * 0.2), int(height * 0.8)),
        )
        axes: tuple[int, int] = (
            np.random.randint(int(width * 0.2), int(width * 0.6)),
            np.random.randint(int(height * 0.2), int(height * 0.6)),
        )
        angle: int = np.random.randint(0, 180)
        # Full 0-360 degree arc, colour 255, thickness -1 (filled).
        cv2.ellipse(img, center, axes, angle, 0, 360, 255, -1)
        return img

    def add_random_holes(
        self,
        img: np.ndarray,
        hole_count: int = 30,
        max_hole_radius: int = 30,
        *,
        min_hole_count: int = 5,
        min_hole_radius: int = 5,
        min_foreground: int = 50,
    ) -> np.ndarray:
        """Return a copy of *img* with random filled black circles punched into it.

        Args:
            img: Binary image whose foreground pixels equal 255. Not modified.
            hole_count: Exclusive upper bound on the number of holes; the
                count is drawn uniformly from ``[min_hole_count, hole_count)``.
            max_hole_radius: Exclusive upper bound on each hole's radius; radii
                are drawn from ``[min_hole_radius, max_hole_radius)``.
            min_hole_count: Inclusive lower bound on the number of holes.
            min_hole_radius: Inclusive lower bound on each hole's radius.
            min_foreground: Stop early once fewer than this many foreground
                pixels remain, so the shape does not end up like swiss cheese.

        Returns:
            A new image with the holes applied. When an upper bound does not
            exceed its lower bound, the upper bound is used as-is instead of
            raising.
        """
        holed = img.copy()
        if hole_count > min_hole_count:
            n_holes = np.random.randint(min_hole_count, hole_count)
        else:
            n_holes = hole_count
        for _ in range(n_holes):
            # Recomputed every iteration so hole centres land on foreground
            # that still exists and the early stop sees earlier holes.
            foreground: np.ndarray = np.argwhere(holed == 255)
            if len(foreground) < min_foreground:
                break
            center_y, center_x = foreground[np.random.choice(len(foreground))]
            if max_hole_radius > min_hole_radius:
                radius = np.random.randint(min_hole_radius, max_hole_radius)
            else:
                radius = max_hole_radius
            # argwhere yields (row, col); OpenCV wants (x, y) == (col, row).
            cv2.circle(holed, (int(center_x), int(center_y)), radius, 0, -1)
        return holed
    


In [None]:
def visualize_ellipses(img1, img2, img3, titles=("Original", "With holes", "With holes")):
    """Show three grayscale images side by side in a single figure.

    Args:
        img1, img2, img3: Images to display, left to right.
        titles: Exactly three panel titles. The defaults keep the historical
            labels; pass a different third title when *img3* is, for example,
            a model reconstruction.

    Raises:
        ValueError: If *titles* does not contain exactly three entries.
    """
    titles = tuple(titles)
    if len(titles) != 3:
        raise ValueError(f"expected 3 titles, got {len(titles)}")
    _, axs = plt.subplots(1, 3, figsize=(10, 5))
    for ax, img, title in zip(axs, (img1, img2, img3), titles):
        ax.imshow(img, cmap='gray')
        ax.axis('off')
        ax.set_title(title)
    plt.show()

## PyTorch

In [None]:
class EllipsesDataset(Dataset):
    """In-memory dataset of (ellipse, ellipse-with-holes) image pairs.

    All samples are generated eagerly in ``__init__`` with ``Ellipses``, so the
    data is fixed for the lifetime of the dataset instance.
    """

    def __init__(self, transform=None, n_samples: int = 100):
        """
        Args:
            transform: Optional callable applied independently to both images
                of a pair in ``__getitem__``.
            n_samples: Number of image pairs to generate.
        """
        self.transform = transform
        self.ellipses_generator = Ellipses()

        self.inputs: list[np.ndarray] = []
        self.targets: list[np.ndarray] = []
        for _ in range(n_samples):
            clean = self.ellipses_generator.generate_random_ellipse()
            # Hand over a copy so the clean input can never alias the holed
            # target, whether or not add_random_holes draws in place.
            holed = self.ellipses_generator.add_random_holes(clean.copy())
            # NOTE(review): for hole filling one would usually feed the holed
            # image as input and use the clean ellipse as target; confirm the
            # intended direction before swapping these.
            self.inputs.append(clean)
            self.targets.append(holed)

    def __len__(self):
        """Return the number of generated pairs."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return ``(input_img, target_img)`` for *idx*, transformed if a transform was given."""
        input_img = self.inputs[idx]
        target_img = self.targets[idx]

        if self.transform:
            input_img = self.transform(input_img)
            target_img = self.transform(target_img)

        return input_img, target_img
    

In [None]:
class Autoencoder(nn.Module):
    """Small convolutional autoencoder for single-channel images.

    Input and output are ``(batch, 1, H, W)``. The final Sigmoid keeps outputs
    in (0, 1), as ``nn.BCELoss`` requires. The strides are such that a 256x256
    input (``CANVAS_SIZE``) is reconstructed at exactly 256x256. Other spatial
    sizes may come back at a different size (28 -> 28, but 128 -> 124).
    """

    def __init__(self):
        super().__init__()
        # Shape comments trace a (b, 1, 256, 256) input.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=3, padding=1),  # b, 16, 86, 86
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2),  # b, 16, 43, 43
            nn.Conv2d(16, 8, 3, stride=2, padding=1),  # b, 8, 22, 22
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=1),  # b, 8, 21, 21
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 16, 3, stride=2),  # b, 16, 43, 43
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 5, stride=3, padding=1),  # b, 8, 129, 129
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 2, stride=2, padding=1),  # b, 1, 256, 256
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode then decode *x*, returning a reconstruction with values in (0, 1)."""
        return self.decoder(self.encoder(x))

In [None]:
def _run_epoch(model, loader, criterion, device, threshold, optimizer=None, show_progress=False):
    """Run one pass over *loader*: train when *optimizer* is given, else evaluate.

    Returns ``(mean batch loss, pixel accuracy)`` as Python floats; both are
    NaN when the loader yields no batches.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    running_loss = 0.0
    correct = 0
    total = 0
    n_batches = 0
    batches = tqdm(loader) if show_progress else loader
    with torch.set_grad_enabled(training):
        for inputs, targets in batches:
            inputs = inputs.to(device)
            targets = targets.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            if training:
                loss.backward()
                optimizer.step()

            running_loss += loss.item()
            # Accuracy is element-wise agreement between the thresholded
            # prediction and the target; meaningful for 0/1 targets.
            preds = (outputs > threshold).float()
            correct += int((preds == targets).sum().item())
            total += targets.numel()
            n_batches += 1

    mean_loss = running_loss / n_batches if n_batches else float("nan")
    accuracy = correct / total if total else float("nan")
    return mean_loss, accuracy


def train_model(
    model: "Autoencoder",
    criterion: torch.nn.Module,
    optimizer: Optimizer,
    dataloader: dict,
    EPOCHS=5,
    *,
    threshold: float = 0.5,
    device: "torch.device | None" = None,
    show_progress: bool = True,
) -> tuple["Autoencoder", dict[str, list[float]]]:
    """Train *model* for ``EPOCHS`` epochs, validating after each one.

    Args:
        model: Network to train in place. Its outputs are compared against
            *threshold* to compute pixel accuracy, so they should be
            probabilities (e.g. produced by a final Sigmoid).
        criterion: Loss called as ``criterion(outputs, targets)``.
        optimizer: Optimizer over ``model``'s parameters.
        dataloader: Mapping with ``'train'`` and ``'val'`` iterables of
            ``(inputs, targets)`` batches.
        EPOCHS: Number of passes over the training data.
        threshold: Outputs strictly above this value count as foreground (1).
        device: Where batches are moved. Defaults to the device of the
            model's first parameter, or the CPU if it has none.
        show_progress: Wrap the training loader in a ``tqdm`` progress bar.

    Returns:
        ``(model, history)`` where ``history`` maps ``'loss'``, ``'accuracy'``,
        ``'val_loss'`` and ``'val_accuracy'`` to lists of per-epoch floats.
        Metrics for an empty loader are NaN.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    history: dict[str, list[float]] = {
        'loss': [],
        'accuracy': [],
        'val_loss': [],
        'val_accuracy': [],
    }

    for epoch in range(EPOCHS):
        epoch_loss, epoch_accuracy = _run_epoch(
            model, dataloader['train'], criterion, device, threshold,
            optimizer=optimizer, show_progress=show_progress,
        )
        val_epoch_loss, val_epoch_accuracy = _run_epoch(
            model, dataloader['val'], criterion, device, threshold,
        )
        history['loss'].append(epoch_loss)
        history['accuracy'].append(epoch_accuracy)
        history['val_loss'].append(val_epoch_loss)
        history['val_accuracy'].append(val_epoch_accuracy)

        print(f'Epoch {epoch+1}/{EPOCHS} Loss: {epoch_loss:.4f} Accuracy: {epoch_accuracy:.4f} Val Loss: {val_epoch_loss:.4f} Val Accuracy: {val_epoch_accuracy:.4f}')

    return model, history

## Train

In [None]:
# Both images of every pair go through the same uint8 -> float tensor conversion.
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

# Two independently generated splits (train first, then val).
datasets: dict[str, Dataset] = {
    split: EllipsesDataset(transform=transform) for split in ('train', 'val')
}

# Only the training split is shuffled.
dataloaders: dict[str, DataLoader] = {
    split: DataLoader(dataset, batch_size=32, shuffle=(split == 'train'))
    for split, dataset in datasets.items()
}

model = Autoencoder().to(device)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

_, history = train_model(
    model,
    criterion,
    optimizer,
    dataloaders,
    EPOCHS=5,
)