## Introduction
This notebook provides the inference code for the Herculaneum fragments, using a 3D ResNet18 model trained on fragments 2 and 3 (fragment 1 is held out for validation).

ResNet training notebook:
1. [Vesuvius Challenge - 3D ResNet Training](https://www.kaggle.com/code/samfc10/vesuvius-challenge-3d-resnet-training)


#### Base parameters:
1. Inference with a pre-trained 3D ResNet18 model
2. Single-fold validation (fragments 2 and 3 for training, fragment 1 for validation)
3. Inference on 192 x 192 x 16 windows

## Setup

In [None]:
import os,cv2
import gc
import sys
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
# from torch.cuda import amp
from torch.utils.data import Dataset, DataLoader
import PIL.Image as Image

sys.path.append("/kaggle/input/resnet3d")
from resnet3d import generate_model
import torch as tc

## Configuration

In [None]:
class CFG:
    """Static configuration for inference: dataset paths and windowing parameters."""

    # ============== paths =============
    comp_dir_path = '/kaggle/input/'
    comp_folder_name = 'vesuvius-challenge-ink-detection'
    comp_dataset_path = f'{comp_dir_path}{comp_folder_name}/'

    # ============== inference config ========
    # Number of consecutive surface-volume layers read around the middle slice
    # of each fragment (the model's depth dimension).
    in_chans = 16
    prd_size = 192  # side length (pixels) of the square crops fed to the model
    # Step between neighbouring crops: 192 // 8 = 24, so every pixel is covered
    # by up to 8 crops along each axis.
    stride = prd_size // 8
    batch_size = 24
    seed = 42
    num_workers = 2

## Create test dataset

In [None]:
def read_image(fragment_id):
    """Load the ``CFG.in_chans`` middle surface-volume layers of a fragment.

    Layers ``start .. start + CFG.in_chans - 1`` with ``start = 32 - in_chans // 2``
    are read as 8-bit grayscale (``cv2.imread(..., 0)``). Each layer is zero-padded
    at the bottom/right so that both spatial sizes grow to a multiple of
    ``CFG.prd_size``. The layers are then stacked along a new last axis.

    Args:
        fragment_id: fragment directory name under ``{CFG.comp_dataset_path}{mode}/``
            (``mode`` is read from the notebook's global namespace).

    Returns:
        np.ndarray of shape (H_padded, W_padded, CFG.in_chans).

    Raises:
        FileNotFoundError: if a layer image cannot be read.
    """
    images = []
    mid = 65 // 2  # assumes 65 slices (00.tif .. 64.tif) per fragment — TODO confirm
    start = mid - CFG.in_chans // 2
    # start + in_chans (rather than mid + in_chans // 2) also yields exactly
    # in_chans layers when in_chans is odd.
    end = start + CFG.in_chans

    for i in tqdm(range(start, end)):
        path = CFG.comp_dataset_path + f"{mode}/{fragment_id}/surface_volume/{i:02}.tif"
        image = cv2.imread(path, 0)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"could not read layer image: {path}")

        # Pad bottom/right up to the next multiple of prd_size. An already aligned
        # side still receives a full extra prd_size of zeros; the mask padding in
        # the inference loop uses the same formula, so keep the two in sync.
        pad0 = CFG.prd_size - image.shape[0] % CFG.prd_size
        pad1 = CFG.prd_size - image.shape[1] % CFG.prd_size
        images.append(np.pad(image, [(0, pad0), (0, pad1)], constant_values=0))

    return np.stack(images, axis=2)  # (H, W, layers)

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset over pre-cut image crops.

    Each item is ``(tensor, coords)``:

    - ``tensor`` is the crop moved to channel-first layout, scaled to [0, 1]
      by dividing by 255, then standardized with mean 0.45 / std 0.225.
    - ``coords`` is ``xys[idx]``, returned unchanged.

    Args:
        images: sequence of (H, W, C) numpy arrays.
        cfg: configuration object (stored only).
        xys: per-crop coordinates, indexable in parallel with ``images``.
        labels: unused; kept for interface compatibility.
    """

    def __init__(self, images, cfg, xys, labels=None):
        self.images, self.cfg, self.labels, self.xys = images, cfg, labels, xys

    def __len__(self):
        # One item per crop.
        return len(self.images)

    def __getitem__(self, idx):
        crop = tc.from_numpy(self.images[idx])
        # (H, W, C) -> (C, H, W), float32 in [0, 1]
        crop = crop.permute(2, 0, 1).to(tc.float32).div(255)
        # Fixed mean/std standardization
        crop = crop.sub(0.45).div(0.225)
        return crop, self.xys[idx]


In [None]:
def make_test_dataset(fragment_id):
    """Tile a fragment into overlapping crops and wrap them in a DataLoader.

    Windows of ``CFG.prd_size`` x ``CFG.prd_size`` are taken on a regular grid
    with step ``CFG.stride``. Windows that are entirely zero (padding, or areas
    without data) are skipped.

    Args:
        fragment_id: fragment directory name, forwarded to ``read_image``.

    Returns:
        ``(test_loader, xyxys)``:

        - ``test_loader`` yields ``(images, xys)`` batches in grid order.
        - ``xyxys`` is an int64 array of shape (N, 4) holding the
          ``(x1, y1, x2, y2)`` box of each kept crop. It is (0, 4) when no
          window contains data.
    """
    test_images = read_image(fragment_id)
    height, width = test_images.shape[:2]

    # Top-left corners of every window on the grid.
    x1_list = range(0, width - CFG.prd_size + 1, CFG.stride)
    y1_list = range(0, height - CFG.prd_size + 1, CFG.stride)

    test_images_list = []
    xyxys = []
    for y1 in y1_list:
        for x1 in x1_list:
            y2 = y1 + CFG.prd_size
            x2 = x1 + CFG.prd_size
            crop = test_images[y1:y2, x1:x2]
            if not crop.any():
                # Avoid feeding all-zero windows to the model.
                continue
            test_images_list.append(crop)
            xyxys.append((x1, y1, x2, y2))

    # np.stack raises on an empty list; build a well-formed (N, 4) array instead.
    xyxys = np.array(xyxys, dtype=np.int64).reshape(-1, 4)

    test_dataset = CustomDataset(test_images_list, CFG, xys=xyxys)

    test_loader = DataLoader(test_dataset,
                             batch_size=CFG.batch_size,
                             shuffle=False,
                             num_workers=CFG.num_workers, pin_memory=True, drop_last=False)

    return test_loader, xyxys

## Build the ResNet18 model

In [None]:
class Decoder(nn.Module):
    """Top-down 2-D decoder that fuses a feature pyramid into a 1-channel logit map.

    Args:
        encoder_dims: channel counts of the feature maps, finest (highest
            resolution) level first, e.g. ``[64, 128, 256, 512]``.
        upscale: factor by which the finest-level logits are upsampled at the end.
    """

    def __init__(self, encoder_dims, upscale):
        super().__init__()
        # convs[i-1] fuses level i-1 with the 2x-upsampled level i and reduces
        # the concatenated channels back to encoder_dims[i-1].
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(encoder_dims[i] + encoder_dims[i - 1], encoder_dims[i - 1], 3, 1, 1, bias=False),
                nn.BatchNorm2d(encoder_dims[i - 1]),
                nn.ReLU(inplace=True)
            ) for i in range(1, len(encoder_dims))])

        self.logit = nn.Conv2d(encoder_dims[0], 1, 1, 1, 0)
        # align_corners=False is PyTorch's default for bilinear; stated explicitly.
        self.up = nn.Upsample(scale_factor=upscale, mode="bilinear", align_corners=False)

    def forward(self, feature_maps):
        """Decode ``feature_maps`` into a (B, 1, H * upscale, W * upscale) logit map.

        ``feature_maps`` is a list ordered finest level first. Each level is
        assumed to have half the spatial size of the previous one. H and W are
        the spatial size of ``feature_maps[0]``. The caller's list is not
        modified.
        """
        # Work on a copy: the fused results replace entries level by level.
        maps = list(feature_maps)
        for i in range(len(maps) - 1, 0, -1):
            f_up = F.interpolate(maps[i], scale_factor=2, mode="bilinear", align_corners=False)
            fused = torch.cat([maps[i - 1], f_up], dim=1)
            maps[i - 1] = self.convs[i - 1](fused)

        x = self.logit(maps[0])
        return self.up(x)


class SegModel(nn.Module):
    """3-D ResNet18 encoder followed by a 2-D ``Decoder`` producing ink logits.

    The encoder comes from ``generate_model`` (external ``resnet3d`` module). The
    feature maps it returns are averaged over axis 2 (assumed to be depth) and
    then decoded to a single-channel logit map.
    """

    def __init__(self):
        super().__init__()
        self.encoder = generate_model(model_depth=18, n_input_channels=1)
        self.decoder = Decoder(encoder_dims=[64, 128, 256, 512], upscale=4)

    def forward(self, x):
        """Return logits for a batch; assumes ``x`` is (batch, layers, H, W) — TODO confirm."""
        if x.ndim == 4:
            # Add a singleton channel axis so the 3-D encoder sees (B, 1, layers, H, W).
            x = x[:, None]

        feat_maps = self.encoder(x)
        # Collapse axis 2 so the 2-D decoder can consume the features.
        feat_maps_pooled = [torch.mean(f, dim=2) for f in feat_maps]
        return self.decoder(feat_maps_pooled)

    def load_pretrained_weights(self, state_dict):
        """Load 3-channel pretrained weights into the 1-channel encoder.

        ``conv1.weight`` is summed over its input-channel axis to fit the single
        input channel. Loading is non-strict, and the load result is printed.
        The caller's ``state_dict`` is not modified.
        """
        # ref - https://timm.fast.ai/models#Case-1:-When-the-number-of-input-channels-is-1
        from collections import OrderedDict

        # Copy so the caller's dict keeps its original conv1 weights. Carry over
        # `_metadata`, which load_state_dict uses for module versioning.
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = OrderedDict(state_dict)
        if metadata is not None:
            state_dict._metadata = metadata

        conv1_weight = state_dict['conv1.weight']
        state_dict['conv1.weight'] = conv1_weight.sum(dim=1, keepdim=True)
        print(self.encoder.load_state_dict(state_dict, strict=False))

#### Load the trained weights

In [None]:
in_submission = True
# True: run on training fragment 1 (enables the metrics cell); False: predict the test set.
IS_DEBUG = False
mode = 'train' if IS_DEBUG else 'test'
TH = 0.55  # probability threshold for ink / no ink

# Fragments to predict: every test fragment on disk, or fragment 1 in debug mode.
if mode == 'test':
    fragment_ids = sorted(os.listdir(CFG.comp_dataset_path + mode))
else:
    fragment_ids = [1]

model = SegModel()
# With device_ids left at its default, DataParallel uses every visible GPU
# instead of requiring exactly two.
model = nn.DataParallel(model)
model = model.cuda()
# The checkpoint holds the weights of the unwrapped SegModel, hence `.module`.
model.module.load_state_dict(tc.load("/kaggle/input/resnet18/resnet18_3d_seg_32_0.58.pt"))
# Inference must run in eval mode. In train mode BatchNorm normalizes with the
# statistics of each (TTA-augmented) batch and keeps updating its running stats.
model.eval()

### TODO: add better visualizations of the model

### Helper functions for inference

#### 1. Run length encoding of the predictions
The `rle` (run-length encoding) function converts the binary predictions into the submission format required by the competition.

In [None]:

# ref.: https://www.kaggle.com/stainsby/fast-tested-rle
def rle(img):
    """Run-length encode a binary mask for the submission file.

    Args:
        img: numpy array, 1 = mask (ink), 0 = background; flattened in
            row-major order.

    Returns:
        Space-separated string of alternating 1-based start positions and
        lengths of the 1-runs, e.g. ``"2 2 6 1"``. It is empty for an all-zero
        mask. The shape of the run-boundary array is also printed (debug output).
    """
    padded = np.concatenate([[0], img.flatten(), [0]])
    # 1-based positions (in the unpadded array) where the value changes.
    boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    print('runs', boundaries.shape)

    # Boundaries alternate run-start / run-end; turn each end into a length.
    starts = boundaries[0::2]
    encoded = boundaries.copy()
    encoded[1::2] = boundaries[1::2] - starts
    return ' '.join(map(str, encoded))

#### 2. Test-time augmentation (TTA)

In [None]:
def TTA(x: tc.Tensor, model: nn.Module):
    """Test-time augmentation over the four 90-degree rotations.

    All four rotations of the batch are predicted in one forward pass. The
    sigmoid probabilities are rotated back to the original orientation and
    averaged.

    Args:
        x: input batch, expected (batch, c, h, w); rotations act on the last
            two axes.
        model: callable returning one logit map per sample. Its output is
            assumed to be single-channel with the same spatial size as ``x``.

    Returns:
        Mean probability tensor of shape (batch, *x.shape[2:]).
    """
    n_rot = 4
    batch = x.shape[0]
    spatial = x.shape[2:]

    # Rotation k turns the input by k * 90 degrees (k = 0 keeps x as is).
    variants = [x if k == 0 else tc.rot90(x, k=k, dims=(-2, -1)) for k in range(n_rot)]
    probs = torch.sigmoid(model(tc.cat(variants, dim=0)))
    probs = probs.reshape(n_rot, batch, *spatial)

    # Undo rotation k with rotation -k so every prediction lines up with x.
    aligned = [tc.rot90(probs[k], k=-k, dims=(-2, -1)) for k in range(n_rot)]
    return tc.stack(aligned, dim=0).mean(0)

#### 3. L1/ Hessian Denoising of the predictions
This idea was proposed during the competition by Brett Olsen:\
[Improving performance with L1/Hessian denoising](https://www.kaggle.com/code/brettolsen/improving-performance-with-l1-hessian-denoising)\

The approach exploits known properties of the ink distribution, in particular:
- The **ink is sparse**: most regions contain no ink. An **L1 regularization term** penalizes noisy data.
- The **ink is continuous**: a pixel is more likely to contain ink when it is next to another pixel containing ink. A **Hessian matrix term** penalizes strongly varying ink/no-ink values.

In [None]:
import cupy as cp
# Array backend for the denoising helpers below: every array operation goes
# through `xp`. Swapping in numpy here would also require replacing any
# CuPy-specific calls (such as `.get()`) used to copy results back to the host.
xp = cp

# Finite-difference stencils whose FFTs build the Hessian penalty in the
# frequency domain: "xx" is the second difference along axis 1 (a row vector),
# "yy" the second difference along axis 0 (a column vector), and "xy" the
# 2x2 mixed-difference stencil.
delta_lookup = {
    "xx": xp.array([[1, -2, 1]], dtype=float),
    "yy": xp.array([[1], [-2], [1]], dtype=float),
    "xy": xp.array([[1, -1], [-1, 1]], dtype=float),
}

def operate_derivative(img_shape, pair):
    """Return ``FFT(k) * conj(FFT(k))`` for the stencil ``k = delta_lookup[pair]``.

    The stencil is zero-padded to the 2-D ``img_shape`` before the FFT, so the
    result has that shape and can be added elementwise to other
    frequency-domain terms.
    """
    assert len(img_shape) == 2
    kernel_fft = xp.fft.fftn(delta_lookup[pair], img_shape)
    return kernel_fft * xp.conj(kernel_fft)

def soft_threshold(vector, threshold):
    """Shrinkage operator: ``sign(v) * max(|v| - threshold, 0)``, elementwise.

    Every entry moves ``threshold`` closer to zero; entries whose magnitude is
    below ``threshold`` become zero.
    """
    shrunk_magnitude = xp.maximum(xp.abs(vector) - threshold, 0)
    return xp.sign(vector) * shrunk_magnitude

def back_diff(input_image, dim):
    """Backward difference of a 2-D array along axis ``dim`` (0 = rows, 1 = columns).

    Along ``dim``: ``out[0] = 0`` and ``out[i] = in[i] - in[i-1]`` for ``i >= 1``.
    The result has the input's shape and is float64.
    """
    assert dim in (0, 1)
    r, n = xp.shape(input_image)
    size = xp.array((r, n))
    position = xp.zeros(2, dtype=int)
    # Two (r+1, n+1) zero buffers, both holding the image at the top-left.
    temp1 = xp.zeros((r+1, n+1), dtype=float)
    temp2 = xp.zeros((r+1, n+1), dtype=float)
    
    temp1[position[0]:size[0], position[1]:size[1]] = input_image
    temp2[position[0]:size[0], position[1]:size[1]] = input_image
    
    # Write the image into temp2 again, shifted by one along `dim`. Index 0 along
    # `dim` keeps the first write, so out[0] = in[0] - in[0] = 0.
    size[dim] += 1
    position[dim] += 1
    temp2[position[0]:size[0], position[1]:size[1]] = input_image
    temp1 -= temp2
    size[dim] -= 1
    # Crop back to the original (r, n) extent.
    # NOTE(review): size/position are `xp` arrays, so with CuPy each slice bound is
    # a device scalar; plain Python ints may be cheaper — verify before changing.
    return temp1[0:size[0], 0:size[1]]

def forward_diff(input_image, dim):
    """Forward difference of a 2-D array along axis ``dim`` (0 = rows, 1 = columns).

    Along ``dim``: ``out[i] = in[i+1] - in[i]`` for ``i < last`` and
    ``out[last] = 0``. The result has the input's shape and is float64.
    """
    assert dim in (0, 1)
    r, n = xp.shape(input_image)
    size = xp.array((r, n))
    position = xp.zeros(2, dtype=int)
    temp1 = xp.zeros((r+1, n+1), dtype=float)
    temp2 = xp.zeros((r+1, n+1), dtype=float)
        
    # Place the image shifted by one along `dim` in both buffers.
    size[dim] += 1
    position[dim] += 1

    temp1[position[0]:size[0], position[1]:size[1]] = input_image
    temp2[position[0]:size[0], position[1]:size[1]] = input_image
    
    # Overwrite temp2 with the unshifted image. Its last index along `dim` keeps
    # the shifted copy's last value, so that position differences to 0.
    size[dim] -= 1
    temp2[0:size[0], 0:size[1]] = input_image
    # temp1 now holds in[i-1] - in[i] at shifted index i.
    temp1 -= temp2
    size[dim] += 1
    # Take the shifted window and negate to get in[i+1] - in[i].
    return -temp1[position[0]:size[0], position[1]:size[1]]

def iter_deriv(input_image, b, scale, mu, dim1, dim2):
    """One auxiliary-variable update for a second-derivative (Hessian) term.

    This has the shape of a split-Bregman step:

    1. Compute the second difference ``g`` of ``input_image`` (forward difference
       along ``dim1``, then backward difference along ``dim2``).
    2. Shrink ``g + b`` with threshold ``1 / mu``.
    3. Update ``b``.

    Returns:
        ``(L, b_new)``. ``L`` is ``scale`` times the same difference operator
        applied to ``d - b_new`` with the axis order swapped.
    """
    second_diff = back_diff(forward_diff(input_image, dim1), dim2)
    shrunk = soft_threshold(second_diff + b, 1 / mu)
    b_next = b + (second_diff - shrunk)
    residual_term = back_diff(forward_diff(shrunk - b_next, dim2), dim1)
    return scale * residual_term, b_next

def iter_xx(*args):
    """Hessian xx term: ``iter_deriv`` with both differences along axis 1 (columns).

    ``args`` are forwarded positionally as ``(input_image, b, scale, mu)``.
    """
    return iter_deriv(*args, dim1=1, dim2=1)

def iter_yy(*args):
    """Hessian yy term: ``iter_deriv`` with both differences along axis 0 (rows).

    ``args`` are forwarded positionally as ``(input_image, b, scale, mu)``.
    """
    return iter_deriv(*args, dim1=0, dim2=0)

def iter_xy(*args):
    """Mixed Hessian term: forward difference along axis 0, backward along axis 1.

    ``args`` are forwarded positionally as ``(input_image, b, scale, mu)``.
    """
    return iter_deriv(*args, dim1=0, dim2=1)

def iter_sparse(input_image, bsparse, scale, mu):
    """Auxiliary-variable update for the L1 (sparsity) term.

    Shrinks ``input_image + bsparse`` with threshold ``1 / mu``, then updates
    ``bsparse``.

    Returns:
        ``(scale * (d - bsparse_new), bsparse_new)``, where ``d`` is the shrunk
        array.
    """
    shrunk = soft_threshold(input_image + bsparse, 1 / mu)
    bsparse_next = bsparse + (input_image - shrunk)
    return scale * (shrunk - bsparse_next), bsparse_next

def denoise_image(input_image, iter_num=100, fidelity=150, sparsity_scale=10, continuity_scale=0.5, mu=1):
    """L1/Hessian denoising of a 2-D probability map, computed with ``xp``.

    Iterates in the Fourier domain, balancing fidelity to ``input_image``
    against two penalties:

    - a Hessian (xx, yy, 2*xy) continuity penalty;
    - an L1 sparsity penalty.

    The auxiliary ``b`` arrays are updated by the ``iter_*`` helpers.

    Args:
        input_image: 2-D ``xp`` array (e.g. a CuPy array on the GPU).
        iter_num: number of iterations.
        fidelity: weight of the data-fidelity term.
        sparsity_scale: weight of the L1 (sparsity) term.
        continuity_scale: weight of the Hessian terms; the mixed xy term uses
            twice this value.
        mu: penalty parameter; the soft thresholds use ``1 / mu``.

    Returns:
        The denoised map, clipped at 0 and rescaled to [0, 1]. If the map is
        constant after clipping, the all-zero map is returned instead of
        dividing by zero.
    """
    image_size = xp.shape(input_image)
    fid_over_mu = fidelity / mu

    # Frequency-domain normalizer shared by every iteration.
    norm_array = (
        operate_derivative(image_size, "xx")
        + operate_derivative(image_size, "yy")
        + 2 * operate_derivative(image_size, "xy")
    )
    norm_array += fid_over_mu + sparsity_scale ** 2

    b_arrays = {key: xp.zeros(image_size, dtype=float) for key in ("xx", "yy", "xy", "L1")}
    g_update = xp.multiply(fid_over_mu, input_image)
    for i in tqdm(range(iter_num), total=iter_num):
        g_update = xp.fft.fftn(g_update)
        if i == 0:
            # Only the fidelity term is present: this recovers input_image
            # (up to FFT round-off).
            g = xp.fft.ifftn(g_update / fid_over_mu).real
        else:
            g = xp.fft.ifftn(xp.divide(g_update, norm_array)).real
        g_update = xp.multiply(fid_over_mu, input_image)

        L, b_arrays["xx"] = iter_xx(g, b_arrays["xx"], continuity_scale, mu)
        g_update += L

        L, b_arrays["yy"] = iter_yy(g, b_arrays["yy"], continuity_scale, mu)
        g_update += L

        L, b_arrays["xy"] = iter_xy(g, b_arrays["xy"], 2 * continuity_scale, mu)
        g_update += L

        L, b_arrays["L1"] = iter_sparse(g, b_arrays["L1"], sparsity_scale, mu)
        g_update += L

    g_update = xp.fft.fftn(g_update)
    g = xp.fft.ifftn(xp.divide(g_update, norm_array)).real

    # Clip negatives, then rescale to [0, 1].
    g[g < 0] = 0
    g -= g.min()
    peak = g.max()
    if peak > 0:
        # Guard: a constant map would otherwise become all-NaN (0 / 0).
        g /= peak
    return g

## Inference

In [None]:
# Set to True to apply the L1/Hessian denoising (CuPy, GPU) before thresholding.
USE_DENOISE = False

results = []
for fragment_id in fragment_ids:
    print("Start Inference on Fragment", fragment_id)
    if not in_submission:
        break

    print("Load test dataset")
    test_loader, xyxys = make_test_dataset(fragment_id)
    print("Total number of batches:", len(test_loader))

    # Mask that says where the fragment is (0/255 image -> 0/1 ints).
    mask_path = CFG.comp_dataset_path + f"{mode}/{fragment_id}/mask.png"
    binary_mask = cv2.imread(mask_path, 0)
    if binary_mask is None:
        # cv2.imread returns None instead of raising on a missing/unreadable file.
        raise FileNotFoundError(f"could not read fragment mask: {mask_path}")
    binary_mask = (binary_mask / 255).astype(int)
    ori_h, ori_w = binary_mask.shape

    # Pad exactly like read_image so the mask and the tiled image share a frame.
    pad0 = CFG.prd_size - ori_h % CFG.prd_size
    pad1 = CFG.prd_size - ori_w % CFG.prd_size
    binary_mask = np.pad(binary_mask, [(0, pad0), (0, pad1)], constant_values=0)

    # Accumulate overlapping window predictions and how often each pixel was covered.
    print("start predictions")
    mask_pred = np.zeros(binary_mask.shape)
    mask_count = np.zeros(binary_mask.shape)
    for images, xys in tqdm(test_loader, total=len(test_loader)):
        images = images.cuda()
        with torch.no_grad():
            y_preds = TTA(images, model)  # (batch, H, W) probabilities
        # One device->host copy per batch instead of one per crop.
        y_preds = y_preds.cpu().numpy()
        for pred, (x1, y1, x2, y2) in zip(y_preds, xys.tolist()):
            mask_pred[y1:y2, x1:x2] += pred
            mask_count[y1:y2, x1:x2] += 1

    # Average the overlaps; the epsilon keeps never-covered pixels at 0.
    mask_pred /= (mask_count + 1e-7)

    ##### NICE PLOTS #####
    fig, axes = plt.subplots(1, 3, figsize=(15, 8))
    axes[0].imshow(mask_count)
    axes[1].imshow(mask_pred.copy())

    if USE_DENOISE:
        mask_pred = denoise_image(xp.array(mask_pred), iter_num=250).get()

    # Drop the padding again.
    mask_pred = mask_pred[:ori_h, :ori_w]
    binary_mask = binary_mask[:ori_h, :ori_w]

    # Threshold the probabilities and keep ink only inside the fragment.
    mask_pred_1 = (mask_pred >= TH).astype(int)
    mask_pred_1 *= binary_mask

    axes[2].imshow(mask_pred_1)
    plt.show()

    inklabels_rle = rle(mask_pred_1)
    results.append((fragment_id, inklabels_rle))

In [None]:
#!apt-get -qq install -y graphviz && pip install -q pydot
#!pip install torchviz
#!pip install graphviz

In [None]:
#make_dot(mask_pred, params = dict())

In [None]:
# clean ntb memory
# Free the large per-fragment buffers. In debug ("train") mode `mask_pred` is kept,
# because the metrics cell below still reads it. The names may not exist when the
# inference loop did not run (in_submission=False), so deletion is tolerant.
_names_to_free = ["mask_count", "test_loader"]
if mode != "train":
    _names_to_free.append("mask_pred")
for _name in _names_to_free:
    globals().pop(_name, None)

gc.collect()
torch.cuda.empty_cache()
plt.clf()
if "fig" in globals():
    fig.clear()
    plt.close(fig)

### TODO: add this evaluation to the training notebook

In [None]:
# get metrics if in training fragments
def metric_to_text(ink, label):
    """Format BCE and threshold-sweep metrics of a probability map as text.

    Args:
        ink: predicted ink probabilities (any shape; flattened).
        label: ground truth with as many elements as ``ink``. The BCE term
            assumes values in [0, 1]; the sweep binarizes labels at 0.5.

    Returns:
        A multi-line string with:

        - a ``bce=`` line;
        - a header;
        - one row per threshold with precision, recall, false-positive rate,
          Dice and the F0.5 score.

        Ratios with a zero denominator (e.g. recall when there are no positive
        labels) are reported as 0 instead of nan/inf.
    """
    def _safe_div(num, den):
        # Undefined ratios (0 / 0, x / 0) are reported as 0 rather than nan/inf.
        return num / den if den != 0 else 0.0

    text = []
    probs = ink.reshape(-1)
    t_raw = label.reshape(-1)

    # Binary cross-entropy on the raw probabilities, clipped to avoid log(0).
    pos = np.log(np.clip(probs, 1e-7, 1))
    neg = np.log(np.clip(1 - probs, 1e-7, 1))
    bce = -(t_raw * pos + (1 - t_raw) * neg).mean()
    text.append(f'bce={bce:0.5f}')

    text.append('th   prec   recall   fpr   dice   score')
    text.append('---------------------------------------')

    # Label-only quantities do not depend on the threshold.
    t = (t_raw > 0.5).astype(np.float32)
    t_sum = t.sum()
    neg_sum = (1 - t).sum()
    beta = 0.5
    for threshold in [0.25, 0.30, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7]:
        p = (probs > threshold).astype(np.float32)  # predictions as 0 or 1
        p_sum = p.sum()
        tp_sum = (p * t).sum()  # true positives

        precision = tp_sum / (p_sum + 0.0001)
        recall = _safe_div(tp_sum, t_sum)
        fpr = _safe_div((p * (1 - t)).sum(), neg_sum)  # false-positive rate

        # F-beta with beta = 0.5: inverse of 0.2 * 1/recall + 0.8 * 1/precision.
        if precision > 0 and recall > 0:
            score = beta * beta / (1 + beta * beta) * 1 / recall + 1 / (1 + beta * beta) * 1 / precision
            score = 1 / score
        else:
            score = 0.0

        dice = _safe_div(2 * tp_sum, p_sum + t_sum)

        text.append(f'{threshold:0.2f}, {precision:0.3f}, {recall:0.3f}, {fpr:0.3f},  {dice:0.3f},  {score:0.3f}')
    return '\n'.join(text)

def load_labels(fragment_id):
    """Return the ink-label image of a fragment as a numpy array.

    Reads ``{CFG.comp_dataset_path}{mode}/{fragment_id}/inklabels.png``, where
    ``mode`` is the notebook-global dataset split. The file handle is closed
    before returning.
    """
    # comp_dataset_path already ends with '/', so no extra separator is needed.
    path = f"{CFG.comp_dataset_path}{mode}/{fragment_id}/inklabels.png"
    with Image.open(path) as img:
        # np.array forces the pixel data to load while the file is still open.
        return np.array(img)

In [None]:
if mode == "train":
    # Evaluate the last predicted fragment (fragment_ids is [1] in debug mode).
    labels = load_labels(fragment_ids[-1]).astype(np.float32)
    # inklabels.png may be stored as 0/255 rather than 0/1; the BCE term in
    # metric_to_text needs labels in [0, 1].
    if labels.max() > 1:
        labels /= 255.0

    # mask_pred was cropped back to the fragment's original (ori_h, ori_w) size
    # at the end of the inference loop, so the labels must NOT be padded. Crop
    # both to their common extent to be safe.
    h = min(labels.shape[0], mask_pred.shape[0])
    w = min(labels.shape[1], mask_pred.shape[1])
    labels = labels[:h, :w]
    pred_eval = mask_pred[:h, :w]
    print("true labels shape", labels.shape)
    print("predicted mask shape", pred_eval.shape)

    text = metric_to_text(pred_eval, labels)
    print(text)

## submission

In [None]:
# Write the submission. The shell copy below first puts the sample submission in
# place, so a valid submission.csv exists even when in_submission is False.
! cp /kaggle/input/vesuvius-challenge-ink-detection/sample_submission.csv submission.csv
if in_submission:
    # One row per predicted fragment: (Id, run-length-encoded ink mask).
    sub = pd.DataFrame(results, columns=['Id', 'Predicted'])
    #sub
    sample_sub = pd.read_csv(CFG.comp_dataset_path + 'sample_submission.csv')
    # Keep the Id order of the sample submission; a left merge leaves Predicted
    # as NaN for any Id without a prediction.
    sample_sub = pd.merge(sample_sub[['Id']], sub, on='Id', how='left')
    #sample_sub
    sample_sub.to_csv("submission.csv", index=False)
    print("ok")