In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# Import necessary libraries
import os
import sys

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

# Directory path used in Google Colab
# project_dir = '/content/drive/MyDrive/Colab Notebooks/HGCAL/visual-inspection'

# Directory path used when running locally (resolved against the notebook's working directory)
project_dir = '../'

# Make the project's `autoencoder` modules importable. Only append when the path is missing so
# that re-running this cell does not keep adding duplicate entries to sys.path.
current_dir = os.path.join(project_dir, 'autoencoder')
if current_dir not in sys.path:
    sys.path.append(current_dir)

from data_loading import *
from training import *

# Seed torch's global RNG so repeated runs draw the same random numbers
torch.manual_seed(42)

# Keep cuDNN on deterministic kernels and turn off its autotuner, which may otherwise pick
# different algorithms from run to run (GPU only; harmless on CPU)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Locations of the datasets folder and of the model checkpoint file
DATASET_PATH = os.path.join(project_dir, 'datasets')
CHECKPOINT_PATH = os.path.join(current_dir, 'small_ae.pt')

# Prefer the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

In [None]:
# Read in the reference image only to get its dimensions. Image.open is lazy and keeps the
# file open; the `with` block closes it once the size has been read.
with Image.open(os.path.join(DATASET_PATH, 'unperturbed_data', 'good_hexaboard.png')) as image:
    # PIL reports size as (width, height)
    width, height = image.size

print('Image width:', width)
print('Image height:', height)

In [None]:
# Segment grid used to split each board image
# THIS SHOULD WORK WITH THE GUI
NUM_VERTICAL_SEGMENTS = 20
NUM_HORIZONTAL_SEGMENTS = 12

# Rotate and segment only: no random augmentation is active in this pipeline
transform = transforms.Compose(
    [
        RotationAndSegmentationTransform(
            height=height,
            width=width,
            vertical_segments=NUM_VERTICAL_SEGMENTS,
            horizontal_segments=NUM_HORIZONTAL_SEGMENTS,
        ),
        # transforms.RandomRotation(degrees=2),
        # transforms.RandomHorizontalFlip(),
        # transforms.RandomVerticalFlip(),
    ]
)

# Load the unperturbed board images through the pipeline above
train_dataset = HexaboardDataset(
    transform=transform,
    image_dir=os.path.join(DATASET_PATH, 'unperturbed_data'),
)

# One board per batch, plus the chunk size
batch_size, chunk_size = 1, 12

# Shuffled loader over the unperturbed boards; any incomplete final batch is dropped
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

In [None]:
# # Save the segments to a folder to synthesize perturbed data
# data_iter = next(iter(train_loader)).squeeze(0)
# # print(data_iter.shape)

# for i, image in enumerate(data_iter):
#     image = image.squeeze(0).moveaxis(0,2).numpy() * 255
#     image = Image.fromarray(np.asarray(np.clip(image, 0, 255), dtype="uint8"))
#     image.save(os.path.join(DATASET_PATH, 'perturbed_data', f'bad_hexaboard_{i + 1}.png'))

In [None]:
# Function to remove the transparency channel
def remove_transparency(
    image_dir: str,
) -> None:
    """Rewrite every PNG in ``image_dir`` in place as a 3-channel RGB image.

    Any alpha channel is dropped, not blended onto a background. This is
    equivalent to keeping only the first three channels of an RGBA array.
    Other modes (palette, grayscale, ...) are converted to RGB instead of
    failing on a 3-D slice. Files that are already RGB are left untouched, so
    re-running this cell does not rewrite the dataset again.

    Args:
        image_dir: Directory whose ``.png`` files (any letter case) are converted.
    """
    for file_name in sorted(os.listdir(image_dir)):
        if not file_name.lower().endswith('.png'):
            continue
        image_path = os.path.join(image_dir, file_name)
        # Convert while the source is open, then let `with` close it before the same
        # path is overwritten.
        with Image.open(image_path) as img:
            if img.mode == 'RGB':
                continue
            rgb_img = img.convert('RGB')
        rgb_img.save(image_path)


remove_transparency(os.path.join(DATASET_PATH, 'perturbed_data'))

In [None]:
# Adjust the number of segments
# THIS SHOULD WORK WITH THE GUI
NUM_VERTICAL_SEGMENTS = 20
NUM_HORIZONTAL_SEGMENTS = 12

# Training pipeline: rotate and segment each board image, then randomly flip the result.
# The flips run after segmentation, so each call makes one random draw for the whole
# segmented output (assuming the segmentation transform returns a single tensor).
transform = transforms.Compose([
    RotationAndSegmentationTransform(
        height=height,
        width=width,
        vertical_segments=NUM_VERTICAL_SEGMENTS,
        horizontal_segments=NUM_HORIZONTAL_SEGMENTS
    ),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
])

# Read in and process the images: unperturbed boards for training, perturbed images for validation
train_dataset = HexaboardDataset(
    image_dir=os.path.join(DATASET_PATH, 'unperturbed_data'),
    transform=transform
)
# NOTE(review): the validation set is the *perturbed* data. If train_autoencoder chooses the
# checkpoint it writes to save_path by lowest validation loss, it keeps the model that
# reconstructs defects best, which works against anomaly detection — confirm.
val_dataset = HexaboardDataset(
    image_dir=os.path.join(DATASET_PATH, 'perturbed_data'),
    transform=transforms.Compose([transforms.ToTensor()])
)

# Set the batch and chunk size
batch_size = 1
chunk_size = 12

# Create data loaders (training is shuffled; validation keeps dataset order)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=True)
# test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

# Print some information about the data
print(f'Train dataset size: {len(train_dataset)}')
print(f'Validation dataset size: {len(val_dataset)}')
# print(f'Test dataset size: {len(test_dataset)}')

# Index each dataset once and reuse the sample. Every dataset[i] call presumably re-loads and
# re-transforms the image (for train_dataset that includes the random flips).
val_sample = val_dataset[0]
train_sample = train_dataset[0]
# NOTE(review): with a ToTensor-only transform, val_sample is presumably one (C, H, W) image, so
# the two labels below describe a single image and one of its channels — confirm.
print(f'Segments Shape: {val_sample.shape}')
print(f'Image shape: {val_sample[0].shape}')
print(f'Image tensor type: {train_sample[0].dtype}')
print(f'Batches: {len(train_loader)}')

In [None]:
# Get the segments' height and width.
# Index the dataset only once: each train_dataset[...] call presumably re-runs the whole
# transform pipeline (rotation/segmentation plus random flips).
# NOTE(review): train_dataset[0][0][0] is assumed to be one channel of the first segment, so its
# first two dims are (height, width) — confirm against RotationAndSegmentationTransform.
segment_height, segment_width = train_dataset[0][0][0].shape[:2]
print('Segment height:', segment_height)
print('Segment width:', segment_width)

In [None]:
# Build the autoencoder sized to a single segment and move it to the selected device.
# nn.Module.to returns the module itself, so the chained call leaves cnn_ae on `device`.
cnn_ae = SimpleCNNAutoEncoder(
    height=segment_height,
    width=segment_width,
    latent_dim=128,
    # NOTE(review): these values look like channel widths rather than kernel sizes —
    # confirm against SimpleCNNAutoEncoder's signature.
    kernel_sizes=[64, 128],
).to(device)
cnn_ae

In [None]:
# Adam optimizer; ExponentialLR multiplies the learning rate by 0.99 on every scheduler step
optimizer = optim.Adam(cnn_ae.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
# NOTE(review): BCEWithLogitsLoss applies the sigmoid itself, so it expects raw logits from the
# model and targets in [0, 1]; confirm the decoder does not already end in a sigmoid.
criterion = nn.BCEWithLogitsLoss()

# Train for 500 epochs. The checkpoint path is handed over as save_path, and the helper returns
# the training history together with the model.
history, cnn_ae = train_autoencoder(
    model=cnn_ae,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    train_loader=train_loader,
    val_loader=val_loader,
    save_path=CHECKPOINT_PATH,
    num_epochs=500,
)

In [None]:
# View the training progress
# Plots the `history` returned by train_autoencoder in the training cell. plot_metrics is not
# defined in this notebook; it presumably comes from `from training import *` — confirm.
plot_metrics(history)

In [None]:
# Load the model's weights from the checkpoint path that was passed to train_autoencoder as
# save_path. weights_only=True limits unpickling to tensors and plain containers, which is all a
# state_dict needs, so no arbitrary pickled code can run. map_location keeps tensors on `device`.
cnn_ae.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True))

In [None]:
# Compare reconstructions of the unperturbed images (train_loader) against the perturbed ones
# (val_loader), visualizing a handful of examples
criterion = nn.BCEWithLogitsLoss()

evaluate_autoencoder(
    model=cnn_ae,
    train_loader=train_loader,
    test_loader=val_loader,
    criterion=criterion,
    num_images=6,
    visualize=True,
)