### Dependencies

In [None]:
import glob
import os
import random
from copy import deepcopy
import time
import subprocess
import gc
from typing import List

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image as Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from PIL import Image
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, SubsetRandomSampler
from torchvision.transforms import ToTensor
from tqdm import tqdm


# Wall-clock budget for the notebook: training stops once more than
# `time_train_max` seconds have passed since `time_start`.
time_start = time.time()
time_elapsed = 0
time_train_max = 8 * 60 * 60  # 8 hours, in seconds

### Model

In [None]:
class SimpleNet(nn.Module):
    """LeNet-style binary classifier over a stack of ``slice_depth`` slices.

    Input is laid out as ``(batch, slice_depth, height, width)``: the slices
    are the convolution input channels. The output has shape ``(batch, 1)``
    and holds raw logits. Pair it with ``BCEWithLogitsLoss``, or apply
    ``torch.sigmoid`` to get probabilities.

    ``fc1`` is a ``LazyLinear`` because the flattened conv output size depends
    on the patch size. Its weights are materialized on the first forward pass.
    ``fc2`` and ``fc3`` have input widths fixed by the preceding layer, so they
    are ordinary, eagerly initialized ``Linear`` layers.
    """

    def __init__(
        self,
        slice_depth: int = 65,
    ):
        super().__init__()
        # Each slice of the volume is one input channel
        self.conv1 = nn.Conv2d(slice_depth, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # in_features = 16 * H' * W' depends on the input patch size -> lazy
        self.fc1 = nn.LazyLinear(120)
        # Input widths are known (fc1 -> 120, fc2 -> 84), so no laziness needed
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 1)

    def forward(self, x):
        """Return ``(batch, 1)`` logits for a ``(batch, slice_depth, H, W)`` input."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

### Utils

In [None]:
def rle(img, threshold=0.5):
    """Run-length encode the pixels of ``img`` that are above ``threshold``.

    The image is flattened in NumPy's default (row-major) order. Pixel
    positions are 1-indexed, so a run covering the first two pixels is
    reported as start ``1``, length ``2``.

    Args:
        img: array-like image of any shape; values ``> threshold`` count as "on".
        threshold: cut-off applied before encoding.

    Returns:
        ``(starts, lengths)``: two equal-length 1-D integer arrays holding the
        1-indexed start position and the length of every run of "on" pixels.
        Both are empty when no pixel is "on".
    """
    # TODO: Histogram of image to see where threshold should be
    binary = (np.asarray(img).flatten() > threshold).astype(np.int8)
    # Zero-pad both ends so runs touching the first or last pixel still
    # produce a matching rise (+1) and fall (-1) in the difference signal;
    # without this, starts and ends can get out of step.
    padded = np.concatenate(([0], binary, [0]))
    edges = np.diff(padded)
    # edges[i] compares padded[i+1] with padded[i]; padded[i+1] is pixel i
    # (0-indexed), i.e. 1-indexed position i + 1.
    starts_ix = np.flatnonzero(edges == 1) + 1
    ends_ix = np.flatnonzero(edges == -1) + 1
    lengths = ends_ix - starts_ix
    return starts_ix, lengths

def get_device():
    """Pick the compute device: CUDA when available, otherwise the CPU.

    Prints which backend was selected.
    """
    use_cuda = torch.cuda.is_available()
    print("Using GPU" if use_cuda else "Using CPU")
    return torch.device("cuda" if use_cuda else "cpu")
    
def get_gpu_memory():
    """Print used/free memory (MiB) for every GPU reported by ``nvidia-smi``.

    Returns:
        A list with one ``(used_mib, free_mib)`` tuple per GPU, in the order
        ``nvidia-smi`` lists them. The list is empty when ``nvidia-smi`` is not
        installed or exits with an error (e.g. on a CPU-only machine), so this
        diagnostic never aborts the caller.
    """
    try:
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=memory.used,memory.free', '--format=csv,nounits,noheader'],
            stdout=subprocess.PIPE,
            text=True,
        )
    except FileNotFoundError:
        print("nvidia-smi not found; skipping GPU memory report")
        return []
    if result.returncode != 0:
        print(f"nvidia-smi exited with code {result.returncode}; skipping GPU memory report")
        return []
    # One "used, free" line per GPU; skip blank lines so empty output can't
    # trip int('').
    gpu_memory = [
        tuple(map(int, line.split(',')))
        for line in result.stdout.splitlines()
        if line.strip()
    ]
    for i, (used, free) in enumerate(gpu_memory):
        print(f"GPU {i}: Memory Used: {used} MiB | Memory Available: {free} MiB")
    return gpu_memory
        
def clear_gpu_memory():
    """Release as much cached CUDA memory as possible (no-op without CUDA).

    ``gc.collect()`` runs first so that tensors kept alive only by reference
    cycles are freed back to PyTorch's caching allocator.
    ``torch.cuda.empty_cache()`` then returns those now-unused cached blocks.
    In the reverse order, memory freed by the collection would stay in the
    cache.
    """
    if torch.cuda.is_available():
        print('Clearing GPU memory')
        gc.collect()
        torch.cuda.empty_cache()

### Dataset

In [None]:
class CurriculumDataset(data.Dataset):
    """Per-pixel patch dataset over one or more scroll fragments.

    Each pixel inside a fragment's mask is one sample. A sample is the
    ``(slice_depth, patch_size_x, patch_size_y)`` window of that fragment's
    normalized slice stack, centered on the pixel, with zero padding past the
    borders. In train mode a sample is ``(patch, label)``, where ``label`` is
    the boolean ink label of the center pixel; otherwise it is just ``patch``.

    Attributes built by ``__init__`` (one list entry per fragment, in
    ``data_dirs`` order):
        fragments: normalized, padded tensors of shape
            ``(slice_depth, H + 2 * (patch_size_x // 2), W + 2 * (patch_size_y // 2))``.
        masks / labels: ``(H, W)`` bool tensors (``labels`` entries are ``None``
            when ``train`` is False).
        means / stds: the masked-voxel statistics used to normalize each fragment.
        original_sizes: ``(width, height)`` of each mask image before resizing.
        mask_indices: ``(N, 3)`` long tensor with one ``(fragment, row, col)``
            row per sample, in unpadded coordinates.
    """

    def __init__(
        self,
        # Directories containing the datasets (one fragment per directory)
        data_dirs: List[str],
        # Expected slices per fragment
        slice_depth: int = 65,
        # Size of an individual patch (rows, cols)
        patch_size_x: int = 1028,
        patch_size_y: int = 256,
        # Image resize ratio
        resize_ratio: float = 1.0,
        # Training vs Testing mode
        train: bool = True,
        # Filenames of the images we'll use
        image_mask_filename='mask.png',
        image_labels_filename='inklabels.png',
        slices_dir_filename='surface_volume',
    ):
        print("Initializing CurriculumDataset")
        # Train mode also loads the labels
        self.train = train
        # Resize ratio reduces the size of the image
        self.resize_ratio = resize_ratio
        # Patches are slice_depth x patch_size_x x patch_size_y
        self.patch_size_x = patch_size_x
        self.patch_size_y = patch_size_y
        self.slice_depth = slice_depth
        self.data_dirs = []
        # Per-fragment storage, indexed by fragment number
        self.fragments = []
        self.masks = []
        self.labels = []
        self.means = []
        self.stds = []
        self.original_sizes = []
        # One (fragment, row, col) row per sample, across all fragments
        self.mask_indices = torch.zeros((0, 3), dtype=torch.long)

        for data_dir in data_dirs:
            if not os.path.exists(data_dir):
                raise FileNotFoundError(f"Data directory {data_dir} does not exist")
            self.data_dirs.append(data_dir)

            mask_filepath = os.path.join(data_dir, image_mask_filename)
            with Image.open(mask_filepath) as mask_file:
                original_size = mask_file.size  # PIL size is (width, height)
            resized_size = (
                int(original_size[0] * self.resize_ratio),
                int(original_size[1] * self.resize_ratio),
            )
            print(f"Fragment {data_dir}: original size {original_size}, resized size {resized_size}")

            mask = self._load_binary_image(mask_filepath, resized_size)
            labels = None
            if train:
                labels = self._load_binary_image(
                    os.path.join(data_dir, image_labels_filename), resized_size)
            fragment = self._load_fragment(
                os.path.join(data_dir, slices_dir_filename), resized_size)

            self.original_sizes.append(original_size)
            self._add_fragment(fragment, mask, labels)

        print(f"CurriculumDataset: {len(self)} samples from {len(self.fragments)} fragment(s)")

    @staticmethod
    def _load_binary_image(filepath, resized_size):
        """Load a black/white image as an ``(H, W)`` bool tensor at ``resized_size``."""
        with Image.open(filepath) as raw:
            # Pillow always resizes mode "1" images with NEAREST, whatever
            # filter is requested, so the result stays strictly binary.
            img = raw.convert("1").resize(resized_size, resample=Image.BILINEAR)
        return torch.from_numpy(np.array(img)).to(torch.bool)

    def _load_fragment(self, slice_dir, resized_size):
        """Stack ``slice_depth`` slice TIFFs into a ``(slice_depth, H, W)`` float32 tensor."""
        width, height = resized_size
        fragment = torch.zeros((self.slice_depth, height, width), dtype=torch.float32)
        for i in tqdm(range(self.slice_depth)):
            slice_filepath = os.path.join(slice_dir, f"{i:02d}.tif")
            with Image.open(slice_filepath) as raw:
                img = raw.convert('F').resize(resized_size, resample=Image.BILINEAR)
            # Scale by 65535 -- presumably 16-bit intensities; confirm the TIFF depth
            fragment[i] = torch.from_numpy(np.array(img) / 65535.0)
        return fragment

    def _add_fragment(self, fragment, mask, labels=None):
        """Normalize, pad and index one ``(D, H, W)`` fragment with its ``(H, W)`` mask."""
        # Statistics over voxels under the mask only: (D, N) selection
        voxels = fragment[:, mask]
        if voxels.numel() < 2:
            # Too few masked voxels for a std; fall back to the whole volume
            voxels = fragment.reshape(-1)
        mean = voxels.mean()
        # Guard against a constant fragment (std == 0)
        std = voxels.std().clamp_min(1e-8)
        normalized = (fragment - mean) / std
        print(f"Fragment normalized with mean {mean.item():.6f}, std {std.item():.6f}")

        # Pad by half a patch on every side so a window starting at padded
        # (x, y) is centered on original pixel (x, y) and never runs off the edge.
        # F.pad order for the last two dims: (left, right, top, bottom).
        half_x = self.patch_size_x // 2
        half_y = self.patch_size_y // 2
        padded = torch.nn.functional.pad(
            normalized, (half_y, half_y, half_x, half_x), mode='constant', value=0.0)

        fragment_id = len(self.fragments)
        rows_cols = torch.nonzero(mask)
        ids = torch.full((rows_cols.shape[0], 1), fragment_id, dtype=torch.long)
        self.mask_indices = torch.cat(
            [self.mask_indices, torch.cat([ids, rows_cols], dim=1)], dim=0)

        self.fragments.append(padded)
        self.masks.append(mask)
        self.labels.append(labels)
        self.means.append(mean.item())
        self.stds.append(std.item())

    def __len__(self):
        return self.mask_indices.shape[0]

    def __getitem__(self, index):
        fragment_id, x, y = self.mask_indices[index].tolist()

        # Padding shifted every pixel by half a patch, so this window is
        # centered on pixel (x, y) (center index patch_size // 2 on each axis).
        patch = self.fragments[fragment_id][
            :,
            x: x + self.patch_size_x,
            y: y + self.patch_size_y,
        ]

        if not self.train:
            # If we're not training, we don't have labels
            return patch

        # Label is the label of the center pixel
        label = self.labels[fragment_id][x, y]
        return patch, label

train_dir = '/home/tren/dev/ashenvenus/data/train/1'
# Build a CurriculumDataset over a single training fragment, shrunk to 10% of
# its original resolution.
train_dataset = CurriculumDataset(
    [train_dir],  # data_dirs
    slice_depth=65,
    patch_size_x=256,
    patch_size_y=64,
    resize_ratio=0.1,
    train=True,
)

### Script

In [None]:
# --- Data locations (Kaggle paths kept for reference) ---
# train_dir = '/kaggle/input/vesuvius-challenge-ink-detection/train/1'
# eval_dir = '/kaggle/input/vesuvius-challenge-ink-detection/test'
train_dir = '/home/tren/dev/ashenvenus/data/train/1'
eval_dir = '/home/tren/dev/ashenvenus/data/test'

# --- Volume / patch geometry ---
slice_depth = 65
patch_size_x = 256
patch_size_y = 64
resize_ratio = 0.25

# --- Training and inference settings ---
batch_size = 512
num_workers = 1
lr = 0.0002
num_epochs = 2
threshold = 0.5  # sigmoid probability above which a pixel counts as ink
train_dataset_size = 1000  # number of samples drawn for training

device = get_device()

model = None
test_batch = None

print("START")
clear_gpu_memory()
get_gpu_memory()

# Build the model and move it to the selected device (Module.to returns self)
model = SimpleNet(slice_depth=slice_depth).to(device)

print("LOADED MODEL")
get_gpu_memory()

# Optional probe: push empty batches through the model to check that
# `batch_size` fits in GPU memory.
# for i in range(10):
#     test_batch = torch.zeros(batch_size, slice_depth, patch_size_x, patch_size_y)
#     test_batch = test_batch.to(device)
#     print(f"LOADED BATCH {i}")
#     get_gpu_memory()
#     with torch.no_grad():
#         model(test_batch)

In [None]:
# Training dataset
train_dataset = ClassificationDataset(
    # Directory containing the dataset
    train_dir,
    # Expected slices per fragment
    slice_depth=slice_depth,
    # Size of an individual patch
    patch_size_x=patch_size_x,
    patch_size_y=patch_size_y,
    # Image resize ratio
    resize_ratio=resize_ratio,
    # Training vs Testing mode
    train=True,
)
total_dataset_size = len(train_dataset)

print("LOADED DATASET")
get_gpu_memory()

# Train on a random subset of at most `train_dataset_size` samples
train_idx = list(range(total_dataset_size))
print(f"Raw train dataset size: {len(train_idx)}")
np.random.shuffle(train_idx)
train_idx = train_idx[:train_dataset_size]
print(f"Reduced train dataset size: {len(train_idx)}")
train_sampler = SubsetRandomSampler(train_idx)

# DataLoaders
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    # A sampler is supplied, so shuffle must stay False (DataLoader rejects both)
    shuffle=False,
    sampler=train_sampler,
    num_workers=num_workers,
    # This will make it go faster if it is loaded into a GPU
    pin_memory=True,
)

# Create optimizers
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_fn = nn.BCEWithLogitsLoss()

# Train the model
model.train()
for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1} of {num_epochs}")

    print("Training...")
    train_loss = 0.0
    num_batches = 0
    for batch_idx, (patch, label) in enumerate(tqdm(train_dataloader)):
        optimizer.zero_grad()
        patch = patch.to(device)
        # Assumes per-sample scalar labels: (batch,) -> (batch, 1) floats to
        # match the model's (batch, 1) logits
        label = label.to(device).unsqueeze(1).to(torch.float32)

        if batch_idx == 0:
            # nvidia-smi spawns a subprocess; report once per epoch, not per batch
            print("TRAINING ON BATCH")
            get_gpu_memory()

        pred = model(patch)
        loss = loss_fn(pred, label)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()  # Accumulate the training loss
        num_batches += 1

        time_elapsed = time.time() - time_start
        if time_elapsed > time_train_max:
            print('Time limit reached, stopping batch')
            break

    if time_elapsed > time_train_max:
        print('Time limit reached, stopping epoch')
        break

    # Average over the batches actually processed (guards an empty loader)
    train_loss /= max(num_batches, 1)
    print(f"Training loss: {train_loss:.4f}")

del train_dataloader, train_dataset, train_sampler
clear_gpu_memory()

# Inference from here on
model.eval()

# Create submission file
submission_filepath = 'submission.csv'
with open(submission_filepath, 'w') as f:
    # Write header
    f.write("Id,Predicted\n")

for subtest_name in os.listdir(eval_dir):

    # Name of sub-directory inside test dir
    subtest_filepath = os.path.join(eval_dir, subtest_name)
    # Evaluation dataset
    eval_dataset = ClassificationDataset(
        # Directory containing the dataset
        subtest_filepath,
        # Expected slices per fragment
        slice_depth=slice_depth,
        # Size of an individual patch
        patch_size_x=patch_size_x,
        patch_size_y=patch_size_y,
        # Image resize ratio
        resize_ratio=resize_ratio,
        # Training vs Testing mode
        train=False,
    )

    # Make a blank prediction image
    pred_image = np.zeros(eval_dataset.mask.shape[1:], dtype=np.uint8)

    # Batched inference; the sampler is sequential, so batches arrive in
    # dataset order
    eval_dataloader = DataLoader(
        eval_dataset,
        batch_size=batch_size,
        shuffle=False,
        sampler=SequentialSampler(eval_dataset),
        num_workers=num_workers,
        # This will make it go faster if it is loaded into a GPU
        pin_memory=True,
    )

    # `offset` is the dataset index of the first sample in the current batch,
    # so sample j of the batch maps to eval_dataset.mask_indices[offset + j]
    offset = 0
    for patch in tqdm(eval_dataloader):
        patch = patch.to(device)
        with torch.no_grad():
            probs = torch.sigmoid(model(patch)).squeeze(1).cpu()
        for j in torch.nonzero(probs > threshold).flatten().tolist():
            pixel_index = eval_dataset.mask_indices[offset + j]
            pred_image[pixel_index[0], pixel_index[1]] = 1
        offset += probs.shape[0]

    # Resize pred_image to original size
    _img = Image.fromarray(pred_image)
    _img = _img.resize((
        eval_dataset.original_image_size_x,
        eval_dataset.original_image_size_y,
    ))
    pred_image = np.array(_img)

    starts_ix, lengths = rle(pred_image)
    # "start length start length ..." built with one join; sum(zip(...), ())
    # re-copies the growing tuple for every run (quadratic in the run count)
    inklabels_rle = " ".join(f"{start} {length}" for start, length in zip(starts_ix, lengths))
    with open(submission_filepath, 'a') as f:
        f.write(f"{subtest_name},{inklabels_rle}\n")