### Dependencies

In [None]:
import glob
import logging
import os
import random
from copy import deepcopy

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image as Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from PIL import Image
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torchvision.transforms import ToTensor
from tqdm import tqdm

# Configure the root logger so INFO-level (and above) records are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the helpers, the dataset and the script below.
log = logging.getLogger(__name__)

### Model

In [None]:
class SimpleNet(nn.Module):
    """Small LeNet-style CNN producing one logit per input patch.

    The network is two ``conv -> ReLU -> max-pool`` stages followed by three
    fully connected layers. The output is a raw logit (no sigmoid applied).

    Args:
        slice_depth: Number of input channels of the first convolution.
    """

    def __init__(
        self,
        slice_depth: int = 65,
    ):
        super().__init__()
        self.conv1 = nn.Conv2d(slice_depth, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # The flattened feature size depends on the spatial size of the input
        # patch, so only this layer needs lazy shape inference.
        self.fc1 = nn.LazyLinear(120)
        # The input widths of the remaining layers are fixed by the layer
        # before them, so they are ordinary (eagerly initialised) layers.
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 1)

    def forward(self, x):
        """Map a batch of patches to logits.

        Args:
            x: Tensor whose dimension 1 has ``slice_depth`` channels, as
                required by ``conv1``.

        Returns:
            Tensor with one logit per batch element (last dimension of size 1).
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

### Utils

In [None]:
def rle(img, threshold=0.5):
    """Run-length encode the pixels of ``img`` that exceed ``threshold``.

    The image is flattened in row-major order and binarised; each maximal run
    of above-threshold pixels is reported as a (start, length) pair.

    Args:
        img: Array-like image of any shape.
        threshold: Values strictly greater than this count as foreground.

    Returns:
        Tuple ``(starts_ix, lengths)`` of equal-length integer arrays, where
        ``starts_ix`` holds the 1-indexed flat position of the first pixel of
        each run and ``lengths`` the number of pixels in that run.
    """
    # TODO: Histogram of image to see where threshold should be
    flat_img = np.asarray(img).ravel()
    binary = (flat_img > threshold).astype(np.int8)
    # Pad with a background pixel on both sides so that runs touching the
    # first or last pixel still produce a start and an end transition;
    # otherwise starts and ends get out of sync and runs are lost.
    padded = np.concatenate(([0], binary, [0]))
    edges = np.diff(padded)
    # edges[i] describes the step into flat pixel i, whose 1-indexed
    # position is i + 1.
    starts_ix = np.flatnonzero(edges == 1) + 1
    ends_ix = np.flatnonzero(edges == -1) + 1
    lengths = ends_ix - starts_ix
    return starts_ix, lengths

def load_image(
    image_filepath: str = 'image.png',
    image_type: str = None,
    tensor_dtype: torch.dtype = torch.float32,
    resize_ratio: float = 1.0,
    viz: bool = False,
) -> torch.Tensor:
    """Load an image file into a tensor, optionally converting and resizing.

    Args:
        image_filepath: Path of the image to open.
        image_type: PIL mode to convert to (e.g. ``'L'`` or ``'F'``); ``None``
            keeps the mode stored in the file.
        tensor_dtype: dtype the resulting tensor is cast to.
        resize_ratio: Scale factor applied to width and height; ``1.0`` skips
            resizing.
        viz: Accepted for interface compatibility; currently unused.

    Returns:
        The output of ``ToTensor`` for the (converted, resized) image, cast to
        ``tensor_dtype``.
    """
    # The context manager guarantees the underlying file is closed even when
    # conversion fails or the format keeps its file handle open after loading.
    with Image.open(image_filepath) as source:
        image = source
        if image_type is not None:
            # 'L' (8-bit pixels, grayscale)
            # 'F' (32-bit floating point pixels)
            log.debug(f"Converting {image_filepath} to {image_type}")
            image = image.convert(image_type)
        if resize_ratio != 1.0:
            log.debug(f"Resizing {image_filepath} by {resize_ratio}")
            image = image.resize((
                int(image.width * resize_ratio),
                int(image.height * resize_ratio)
            ), resample=Image.BILINEAR)
        # Convert while the file is still open: the pixel data of an
        # unconverted image is only read on first access.
        tensor = ToTensor()(image)
    return tensor.to(tensor_dtype)

def get_device():
    """Return the CUDA device when one is available, else the CPU device.

    The choice is reported through the module logger at INFO level.
    """
    has_cuda = torch.cuda.is_available()
    log.info("Using GPU" if has_cuda else "Using CPU")
    return torch.device("cuda" if has_cuda else "cpu")

### Dataset

In [None]:
class ClassificationDataset(data.Dataset):
    """Per-pixel patch dataset built from one fragment directory.

    Every pixel selected by the mask image yields one sample: a
    ``(slice_depth, patch_size_x, patch_size_y)`` window of the slice volume
    whose centre is that pixel, normalised with the mean and standard
    deviation of all masked voxels. In training mode each sample also carries
    the label of the centre pixel.

    Throughout the class "x" refers to tensor dimension 1 and "y" to tensor
    dimension 2 of the tensors returned by ``load_image``.

    Args:
        data_dir: Directory containing the mask, the labels/IR images
            (training only) and the slices sub-directory.
        image_mask_filename: File name of the mask image inside ``data_dir``.
        image_labels_filename: File name of the labels image.
        image_ir_filename: File name of the IR image (existence is checked in
            training mode; the image itself is not loaded).
        slices_dir_filename: Name of the sub-directory holding ``NN.tif``
            slices.
        slice_depth: Number of slices (``00.tif`` .. ``{slice_depth-1}.tif``)
            to stack.
        patch_size_x: Patch extent along dimension 1.
        patch_size_y: Patch extent along dimension 2.
        resize_ratio: Scale factor applied to every loaded image.
        dataset_dtype: dtype of the slice volume.
        train: When true, labels are loaded and returned with each patch.
        viz: Stored and forwarded to ``load_image``.
    """

    def __init__(
        self,
        # Directory containing the dataset
        data_dir: str,
        # Filenames of the images we'll use
        image_mask_filename='mask.png',
        image_labels_filename='inklabels.png',
        image_ir_filename='ir.png',
        slices_dir_filename='surface_volume',
        # Expected slices per fragment
        slice_depth: int = 65,
        # Size of an individual patch
        patch_size_x: int = 1028,
        patch_size_y: int = 256,
        # Image resize ratio
        resize_ratio: float = 1.0,
        # Dataset datatype
        dataset_dtype: torch.dtype = torch.float32,
        # Training vs Testing mode
        train: bool = True,
        # Visualize the dataset
        viz: bool = False,
    ):
        self.train = train
        self.viz = viz
        log.info(
            f"Initializing {type(self).__name__} with data_dir={data_dir}")

        # Validate the cheap numeric arguments before touching the disk
        assert slice_depth > 0, \
            f"slice_depth must be positive, got {slice_depth}"
        assert patch_size_x > 0 and patch_size_y > 0, \
            f"Patch sizes must be positive, got ({patch_size_x}, {patch_size_y})"

        # Verify paths and directories for images and metadata
        self.data_dir = data_dir
        assert os.path.exists(
            data_dir), f"Data directory {data_dir} does not exist"
        self.image_mask_filename = image_mask_filename
        self.image_mask_filepath = os.path.join(
            data_dir, self.image_mask_filename)
        assert os.path.exists(
            self.image_mask_filepath), f"Mask file {self.image_mask_filepath} does not exist"
        self.image_labels_filename = image_labels_filename
        self.image_labels_filepath = os.path.join(
            data_dir, self.image_labels_filename)
        self.slices_dir = os.path.join(data_dir, slices_dir_filename)
        assert os.path.exists(
            self.slices_dir), f"Slices directory {self.slices_dir} does not exist"
        if self.train:
            assert os.path.exists(
                self.image_labels_filepath), f"Labels file {self.image_labels_filepath} does not exist"
            self.image_ir_filename = image_ir_filename
            self.image_ir_filepath = os.path.join(
                data_dir, self.image_ir_filename)
            assert os.path.exists(
                self.image_ir_filepath), f"IR file {self.image_ir_filepath} does not exist"

        # Assert that all expected slices are present before loading any of
        # them, so a missing file is reported up front.
        self.slice_depth = slice_depth
        slice_filepaths = [
            os.path.join(self.slices_dir, f"{i:02d}.tif")
            for i in range(self.slice_depth)
        ]
        missing = [path for path in slice_filepaths if not os.path.exists(path)]
        assert not missing, f"Missing slice files: {missing}"

        # Resize ratio reduces the size of the image
        self.resize_ratio = resize_ratio

        # Load the meta data (mask and labels)
        self.mask = load_image(
            self.image_mask_filepath,
            image_type='L',
            tensor_dtype=torch.bool,
            resize_ratio=self.resize_ratio,
            viz=self.viz,
        )
        if self.train:
            self.labels = load_image(
                self.image_labels_filepath,
                image_type='L',
                tensor_dtype=torch.bool,
                resize_ratio=self.resize_ratio,
                viz=self.viz,
            )

        # Dataset type determines precision of the data
        self.dataset_dtype = dataset_dtype

        # Patch sizes and the half-sizes used as padding on each side
        self.patch_size_x = patch_size_x
        self.patch_size_y = patch_size_y
        self.patch_half_size_x = self.patch_size_x // 2
        self.patch_half_size_y = self.patch_size_y // 2

        # Load the slices (tif files) straight into a zero-padded volume.
        # Allocating the padded tensor once avoids holding an unpadded and a
        # padded copy of the whole fragment in memory at the same time.
        self.fragment = None
        interior = None
        for i, slice_filepath in enumerate(tqdm(slice_filepaths)):
            _slice = load_image(
                slice_filepath,
                image_type='F',
                tensor_dtype=dataset_dtype,
                resize_ratio=self.resize_ratio,
                viz=self.viz,
            )
            if self.fragment is None:
                # The first slice determines the size of the volume
                self.fragment_size_x = _slice.shape[1]
                self.fragment_size_y = _slice.shape[2]
                self.fragment = torch.zeros(
                    self.slice_depth,
                    self.fragment_size_x + 2 * self.patch_half_size_x,
                    self.fragment_size_y + 2 * self.patch_half_size_y,
                    dtype=self.dataset_dtype,
                )
                # View of the un-padded region; writes go into self.fragment
                interior = self.fragment[
                    :,
                    self.patch_half_size_x:
                        self.patch_half_size_x + self.fragment_size_x,
                    self.patch_half_size_y:
                        self.patch_half_size_y + self.fragment_size_y,
                ]
            interior[i] = _slice.squeeze(0)

        # Get the mean and std of the fragment at the mask indices (the mask
        # broadcasts over the slice dimension; padding is excluded)
        masked_fragment = torch.masked_select(interior, self.mask)
        self.mean = torch.mean(masked_fragment, dim=-1)
        self.std = torch.std(masked_fragment, dim=-1)

        # Get indices where mask is 1
        self.mask_indices = torch.nonzero(self.mask.squeeze(0))

    def __len__(self):
        """Return the number of masked pixels, i.e. the number of samples."""
        return self.mask_indices.shape[0]

    def __getitem__(self, index):
        """Return the normalised patch (and label when training) for a pixel.

        Args:
            index: Position in ``mask_indices`` of the pixel to sample.

        Returns:
            ``(patch, label)`` in training mode, otherwise just ``patch``.
        """
        # Get the x, y of the pixel in un-padded coordinates
        x, y = self.mask_indices[index].tolist()

        # Because the volume is padded by half a patch on each side, slicing
        # from (x, y) in padded coordinates gives a window centred on the pixel
        patch = self.fragment[
            :,
            x: x + self.patch_size_x,
            y: y + self.patch_size_y,
        ]

        # Normalize the patch based on dataset statistics
        patch = (patch - self.mean) / self.std

        if not self.train:
            # If we're not training, we don't have labels
            return patch

        # Label is the label of the center pixel
        label = self.labels[0, x, y]
        return patch, label

### Script

In [None]:
# Paths and hyper-parameters
train_dir = '/kaggle/input/vesuvius-challenge-ink-detection/train/1'
eval_dir = '/kaggle/input/vesuvius-challenge-ink-detection/test'
slice_depth = 65
patch_size_x = 64
patch_size_y = 64
resize_ratio = 0.25
batch_size = 128
num_workers = 32
lr = 0.01
num_epochs = 2
threshold = 0.5

device = get_device()

# Load the model, try to fit on GPU
model = SimpleNet(
    slice_depth=slice_depth,
)
model = model.to(device)

# Training dataset
train_dataset = ClassificationDataset(
    # Directory containing the dataset
    train_dir,
    # Expected slices per fragment
    slice_depth=slice_depth,
    # Size of an individual patch
    patch_size_x=patch_size_x,
    patch_size_y=patch_size_y,
    # Image resize ratio
    resize_ratio=resize_ratio,
    # Training vs Testing mode
    train=True,
)

# Sampler for training
train_sampler = RandomSampler(train_dataset)

# DataLoaders
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    # shuffle must stay off because an explicit sampler is supplied
    shuffle=False,
    sampler=train_sampler,
    num_workers=num_workers,
    # This will make it go faster if it is loaded into a GPU
    pin_memory=True,
)

# Dry run with a single patch so that any lazily-shaped layers materialise
# their parameters before the optimizer is created.
with torch.no_grad():
    sample_patch, _ = train_dataset[0]
    model(sample_patch.unsqueeze(0).to(device))

# Create optimizer and loss
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_fn = nn.BCEWithLogitsLoss()

# Train the model
for epoch in range(num_epochs):
    log.info(f"Epoch {epoch + 1} of {num_epochs}")

    log.info("Training...")
    model.train()
    train_loss = 0
    for patch, label in tqdm(train_dataloader):
        optimizer.zero_grad()
        patch = patch.to(device)
        label = label.to(device).unsqueeze(1).to(torch.float32)
        pred = model(patch)
        loss = loss_fn(pred, label)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()  # Accumulate the training loss

    # Calculate the average training loss
    train_loss /= len(train_dataloader)
    log.info(f"Training loss: {train_loss:.4f}")

del train_dataloader, train_dataset, train_sampler

# Create submission file
submission_filepath = 'submission.csv'
with open(submission_filepath, 'w') as f:
    # Write header
    f.write("Id,Predicted\n")

# Predict every fragment found in the evaluation directory
model.eval()
for subtest_name in sorted(os.listdir(eval_dir)):

    # Name of sub-directory inside test dir
    subtest_filepath = os.path.join(eval_dir, subtest_name)
    if not os.path.isdir(subtest_filepath):
        # Stray files are not fragments
        continue

    # Evaluation dataset
    eval_dataset = ClassificationDataset(
        # Directory containing the dataset
        subtest_filepath,
        # Expected slices per fragment
        slice_depth=slice_depth,
        # Size of an individual patch
        patch_size_x=patch_size_x,
        patch_size_y=patch_size_y,
        # Image resize ratio
        resize_ratio=resize_ratio,
        # Training vs Testing mode
        train=False,
    )

    # Make a blank prediction image (same size as the resized mask)
    pred_image = np.zeros(eval_dataset.mask.shape[1:], dtype=np.uint8)

    # DataLoaders
    eval_dataloader = DataLoader(
        eval_dataset,
        # Batched inference instead of one forward pass per pixel
        batch_size=batch_size,
        # shuffle must stay off because an explicit sampler is supplied
        shuffle=False,
        sampler=SequentialSampler(eval_dataset),
        num_workers=num_workers,
        # This will make it go faster if it is loaded into a GPU
        pin_memory=True,
    )

    # The sampler is sequential, so the samples of each batch correspond to
    # the next consecutive rows of mask_indices.
    offset = 0
    with torch.no_grad():
        for patch in tqdm(eval_dataloader):
            patch = patch.to(device)
            prob = torch.sigmoid(model(patch)).squeeze(1).cpu()
            pixel_index = eval_dataset.mask_indices[
                offset: offset + prob.shape[0]]
            ink_index = pixel_index[prob > threshold]
            pred_image[
                ink_index[:, 0].numpy(), ink_index[:, 1].numpy()] = 1
            offset += prob.shape[0]

    # Predictions were made on the resized fragment; scale them back to the
    # size of the mask on disk so the run-length encoding refers to pixels of
    # the original image.
    with Image.open(eval_dataset.image_mask_filepath) as full_mask:
        full_width, full_height = full_mask.size
    if pred_image.shape != (full_height, full_width):
        pred_image = np.array(
            Image.fromarray(pred_image).resize(
                (full_width, full_height), resample=Image.NEAREST)
        )

    starts_ix, lengths = rle(pred_image)
    # Generator join keeps this linear in the number of runs
    inklabels_rle = " ".join(
        f"{start} {length}" for start, length in zip(starts_ix, lengths))
    with open(submission_filepath, 'a') as f:
        f.write(f"{subtest_name},{inklabels_rle}\n")

    # Release the fragment before the next one is loaded
    del eval_dataloader, eval_dataset