In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the path of every file shipped with the attached Kaggle inputs.
for root, _subdirs, names in os.walk('/kaggle/input'):
    for name in names:
        print(os.path.join(root, name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# %pip install deepspeed

In [None]:
# import shutil
# shutil.rmtree("/kaggle/working/lightning_logs")

In [None]:
import pytorch_lightning as lt
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
import tqdm.notebook as notebook
from torch.utils.data import Dataset, DataLoader
import PIL.Image as Image
import io
import gc
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint,StochasticWeightAveraging
from torchmetrics.classification import BinaryFBetaScore as Fbeta
# import deepspeed

In [None]:
# Use the current CUDA device when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Initialization block: notebook-wide configuration constants.

# Depth window: load_img reads surface_volume slices Z_START .. Z_START + Z_DIM - 1.
Z_START = 25
Z_DIM = 10
# Zero-padding width added around each fragment, which is also the half-width
# of the square patch cut around every pixel (patch side = 2 * BUFFER + 1).
BUFFER = 30
# Samples per batch for the training and validation DataLoaders.
BATCH_SIZE = 128

In [None]:
# client = storage.Client()
# bucket = client.get_bucket('dlproject-7643')
# /kaggle/input/vesuvius-challenge-ink-detection/train
def get_img_bucket(text):
    """Open the image stored at the local path ``text`` and return the PIL Image.

    The name is a leftover from the commented-out version of this helper, which
    downloaded the bytes from a GCS bucket; it now reads straight from disk.
    """
    return Image.open(text)

def load_img(x, root="/kaggle/input/vesuvius-challenge-ink-detection/train"):
    """Load one training fragment: a slab of surface-volume slices, its ink labels and its mask.

    Args:
        x: fragment id (e.g. "1"); it names the sub-directory of ``root``.
        root: directory holding the per-fragment training folders. Defaults to
            the Kaggle input path this notebook was written against.

    Returns:
        data: float32 array of shape (Z_DIM, H, W) holding slices
            ``Z_START .. Z_START + Z_DIM - 1`` of ``surface_volume``, each divided
            by 65535.0 (the uint16 maximum), so 16-bit TIFFs map into [0, 1].
        gold: ``inklabels.png`` converted to a numpy array.
        mask: ``mask.png`` converted to a numpy array.
    """
    fragment_dir = os.path.join(root, str(x))
    # Slice files are named by a two-digit z index: 00.tif, 01.tif, ...
    file_paths = [
        os.path.join(fragment_dir, "surface_volume", "{:02d}.tif".format(z))
        for z in range(Z_START, Z_START + Z_DIM)
    ]

    # Fill a preallocated stack rather than collecting a list and copying it with
    # np.array(...): every slice is large, so this roughly halves peak memory.
    data = None
    for z, path in enumerate(notebook.tqdm(file_paths)):
        layer = np.asarray(get_img_bucket(path), dtype=np.float32) / 65535.0
        if data is None:
            data = np.empty((len(file_paths),) + layer.shape, dtype=np.float32)
        data[z] = layer
    if data is None:
        # No slices requested (Z_DIM == 0): same empty result np.array([]) gave before.
        data = np.array([])

    gold = np.array(get_img_bucket(os.path.join(fragment_dir, "inklabels.png")))
    mask = np.array(get_img_bucket(os.path.join(fragment_dir, "mask.png")))
    return data, gold, mask

def pad_array(x, Buffer=BUFFER):
    """Zero-pad the last two axes of a 3-D array by ``Buffer`` on each side.

    The leading axis (the slice/depth axis for a volume) is left unpadded.
    """
    border = (Buffer, Buffer)
    return np.pad(x, ((0, 0), border, border), mode='constant', constant_values=0)

def get_pixels(mask, rect=(2500, 3500, 1500, 2500), buffer=None, restrict_to_mask=False):
    """Split one fragment's pixel coordinates into training and validation sets.

    Pixels inside ``rect = (y0, y1, x0, x1)``, i.e. the half-open slice
    ``[y0:y1, x0:x1]`` of the unpadded image, form the validation set. Every
    other pixel forms the training set. Coordinates are returned in the padded
    frame produced by ``pad_datasets`` (shifted by ``buffer`` on both axes), so
    they index the padded tensors directly.

    Args:
        mask: 2-D array for the fragment. Only its shape is used unless
            ``restrict_to_mask`` is True.
        rect: (y0, y1, x0, x1) bounds of the validation rectangle.
        buffer: padding width; defaults to the module-level ``BUFFER``.
        restrict_to_mask: when True, keep only pixels where ``mask > 0`` in both
            sets. Defaults to False, which preserves the original behaviour.

    Returns:
        (train_pixels, val_pixels): int64 tensors of shape (N, 2) with
        (row, col) coordinates in row-major order.
    """
    if buffer is None:
        buffer = BUFFER
    mask = np.asarray(mask)
    y0, y1, x0, x1 = rect

    inside = np.zeros(mask.shape, dtype=bool)
    inside[y0:y1, x0:x1] = True
    outside = ~inside
    if restrict_to_mask:
        valid = mask > 0
        inside &= valid
        outside &= valid

    # The zero border added by padding is never selected, so argwhere on the
    # unpadded boolean arrays plus a constant offset yields the same coordinates
    # without allocating padded float copies of the whole image.
    val_pixels = torch.argwhere(torch.from_numpy(inside)) + buffer
    train_pixels = torch.argwhere(torch.from_numpy(outside)) + buffer
    return train_pixels, val_pixels

def pad_datasets(data, mask, gold, buffer=None):
    """Zero-pad a fragment's volume, mask and labels, and convert them to torch tensors.

    Args:
        data: 3-D array (depth, H, W); only the last two axes are padded.
        mask: 2-D array (H, W).
        gold: 2-D label array (H, W).
        buffer: padding width on every spatial side; defaults to the
            module-level ``BUFFER``.

    Returns:
        (data, mask, gold): ``data`` keeps its dtype and has shape
        (depth, H + 2*buffer, W + 2*buffer); ``mask`` and ``gold`` are float32
        tensors of shape (H + 2*buffer, W + 2*buffer).
    """
    if buffer is None:
        buffer = BUFFER
    spatial = ((buffer, buffer), (buffer, buffer))
    mask = torch.from_numpy(np.pad(mask, spatial, mode='constant', constant_values=0)).float()
    gold = torch.from_numpy(np.pad(gold, spatial, mode='constant', constant_values=0)).float()
    # np.pad already returns a fresh C-contiguous array, so wrap it directly.
    # Re-stacking it slice by slice with torch.stack made a second full copy
    # of the (multi-GB) volume.
    data = torch.from_numpy(np.pad(data, ((0, 0),) + spatial, mode='constant', constant_values=0))
    return data, mask, gold

In [None]:
# Load the training fragments, pad them, and build the train/validation pixel
# index. Each index row is (row, col, fragment_id) in padded coordinates;
# SubvolumeDataset uses fragment_id to pick the matching entry of the dicts below.
input_data = {}
gold_labels = {}
t_pixels = []
v_pixels = []


def _tag_with_fragment(coords, fragment_id):
    """Append a constant fragment-id column to (N, 2) coords; returns an int32 (N, 3) tensor."""
    # Build the id column as an integer tensor rather than i * torch.ones(...):
    # concatenating with a float column promoted every coordinate to float32
    # before the final .int() cast.
    ids = torch.full((coords.shape[0], 1), fragment_id, dtype=coords.dtype)
    return torch.cat((coords, ids), dim=1).int()


for i in [1, 3]:
    data, gold, mask = load_img(str(i))
    train_pixels, val_pixels = get_pixels(mask)
    data, mask, gold = pad_datasets(data, mask, gold)
    t_pixels.append(_tag_with_fragment(train_pixels, i))
    v_pixels.append(_tag_with_fragment(val_pixels, i))
    gold_labels[i] = gold
    input_data[i] = data

train_pixels = torch.cat(t_pixels, dim=0)
val_pixels = torch.cat(v_pixels, dim=0)

# The padded tensors stay alive through input_data / gold_labels; only the
# loop-local names and the per-fragment index pieces are dropped here.
del data, gold
del t_pixels, v_pixels
gc.collect()

In [None]:
class SubvolumeDataset(Dataset):
    """Map-style dataset yielding one (subvolume, label) pair per indexed pixel.

    Args:
        image_stack: mapping fragment_id -> padded volume tensor indexed as
            [depth, row, col] (as produced by pad_datasets).
        label: mapping fragment_id -> padded 2-D label tensor.
        pixels: integer tensor of shape (N, 3) whose rows are (row, col, fragment_id).
        buffer: half-width of the square patch; defaults to the module-level
            ``BUFFER``. Coordinates must lie at least ``buffer`` pixels inside
            the padded arrays.
    """

    def __init__(self, image_stack, label, pixels, buffer=None):
        self.image_stack = image_stack
        self.label = label
        self.pixels = pixels
        self.buffer = BUFFER if buffer is None else buffer

    def __len__(self):
        return len(self.pixels)

    def __getitem__(self, index):
        """Return ``(subvolume, inklabel)`` for the pixel at position ``index``.

        ``subvolume`` has shape (1, depth, 2*buffer+1, 2*buffer+1): every depth
        slice of the fragment around (row, col), with a leading channel axis
        for Conv3d. ``inklabel`` has shape (1,).

        Raises:
            RuntimeError: if the patch would extend past the array edges.
        """
        y, x, key = (int(v) for v in self.pixels[index])
        half = self.buffer
        side = 2 * half + 1

        patch = self.image_stack[key][:, y - half:y + half + 1, x - half:x + half + 1]
        # Fail loudly with context instead of the old bare `except:` that
        # printed debug output and then re-ran the same failing reshape.
        if tuple(patch.shape[-2:]) != (side, side):
            raise RuntimeError(
                "Subvolume at (y={}, x={}, fragment={}) has spatial shape {}, expected ({}, {}); "
                "the pixel is closer than buffer={} to the array edge.".format(
                    y, x, key, tuple(patch.shape[-2:]), side, side, half))

        subvolume = patch.unsqueeze(0)
        inklabel = self.label[key][y, x].view(1)
        return subvolume, inklabel

In [None]:
def _make_loader(pixels, shuffle):
    """Wrap a pixel index in a SubvolumeDataset and a DataLoader of BATCH_SIZE samples."""
    dataset = SubvolumeDataset(input_data, gold_labels, pixels)
    return dataset, DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle)


# Training samples are reshuffled every epoch; validation keeps a fixed order
# so batch k always covers the same pixels.
train_dataset, train_loader = _make_loader(train_pixels, shuffle=True)
val_dataset, val_loader = _make_loader(val_pixels, shuffle=False)

In [None]:
class Litmodel(lt.LightningModule):
    """LightningModule that trains ``model`` as a two-class (ink / no-ink) classifier.

    ``model`` must return two logits per sample. Targets are one-hot encoded and
    scored with binary cross-entropy on logits. No validation_step is defined;
    validation is done separately on the saved checkpoint.
    """

    def __init__(self, model, **kwargs):
        super().__init__(**kwargs)
        self.model = model
        # NOTE(review): this also records `model` (an nn.Module) as a
        # hyperparameter, so it is pickled into every checkpoint. Lightning
        # suggests save_hyperparameters(ignore=["model"]). It is kept as-is so
        # existing checkpoints and load_from_checkpoint() keep working.
        self.save_hyperparameters()

    def _compute_loss(self, batch):
        """BCE-with-logits loss for an ``(x, y)`` batch whose ``y`` holds 0/1 labels."""
        x, y = batch
        # Call the module (not .forward) so hooks run. Lightning has already
        # moved the batch to this module's device, so no manual .to(DEVICE) is needed.
        logits = self.model(x)
        # reshape(-1) instead of squeeze(): squeeze() turns a size-1 batch into a
        # 0-d tensor whose one-hot has shape (2,), not the (1, 2) of the logits.
        target = F.one_hot(y.reshape(-1).long(), num_classes=2).float()
        return F.binary_cross_entropy_with_logits(logits, target)

    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log("batch_idx", batch_idx)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        return optimizer
        
    

In [None]:
class InkDetection(nn.Module):
    """Small 3-D CNN that maps a (B, 1, depth, H, W) patch to two class logits.

    Three stages of Conv3d(3x3x3, padding 1) -> LeakyReLU -> BatchNorm3d ->
    MaxPool3d(2) widen the channels 1 -> 16 -> 32 -> 64. The features are
    flattened and passed through an MLP head
    (LazyLinear(128) -> ReLU -> Dropout(0.2) -> Linear(128, 128) -> ReLU ->
    Dropout(0.2) -> Linear(128, 2)). The output is raw logits: ``final`` (a
    Sigmoid) is defined but not applied in forward.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept fixed. They determine the
        # state_dict keys that saved checkpoints load against, and the order in
        # which parameters draw from the RNG at initialization.
        self.conv1 = nn.Conv3d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.act1 = nn.LeakyReLU()
        self.BN1 = nn.BatchNorm3d(16)
        self.conv2 = nn.Conv3d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.BN2 = nn.BatchNorm3d(32)
        self.conv3 = nn.Conv3d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.BN3 = nn.BatchNorm3d(64)
        self.pool = nn.MaxPool3d(2, 2)
        # LazyLinear infers its input width from the first batch it sees, so the
        # head adapts to whatever depth/patch size reaches it.
        self.ffn = nn.Sequential(
            nn.LazyLinear(128), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(128, 128), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(128, 2),
        )
        self.final = nn.Sigmoid()

    def forward(self, X):
        """Return (B, 2) logits for the input batch ``X``."""
        # The single LeakyReLU is stateless, so it is shared by every stage.
        stages = ((self.conv1, self.BN1), (self.conv2, self.BN2), (self.conv3, self.BN3))
        features = X
        for conv, norm in stages:
            features = self.pool(norm(self.act1(conv(features))))
        return self.ffn(torch.flatten(features, start_dim=1))

In [None]:
# torch.save(model,"model.pt")
# torch.save(model.state_dict(),"model_weights.pt")

In [None]:
class MyEarlyStopping(EarlyStopping):
    """EarlyStopping variant that skips the validation-end check and caps training by batch index.

    Args:
        *args, **kwargs: forwarded unchanged to ``EarlyStopping`` (e.g. ``monitor``, ``mode``).
        max_batches: when a training batch with ``batch_idx >= max_batches`` is
            about to start, the trainer is asked to stop. Defaults to 10000.
            This ends the whole fit, not just the current epoch; use
            ``Trainer(limit_train_batches=...)`` for a per-epoch cap.

    NOTE(review): with no validation loop, the inherited train-epoch-end check
    may still look for the monitored metric (e.g. "val_loss"). Confirm it is
    logged, or construct with ``strict=False``.
    """

    def __init__(self, *args, max_batches=10000, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_batches = max_batches

    def on_validation_end(self, trainer, pl_module):
        # Deliberately disable EarlyStopping's validation-time check.
        pass

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        # Lightning invokes callback hooks as (trainer, pl_module, batch, batch_idx).
        # The previous signature (batch, batch_idx) raised TypeError when called.
        # A Callback's return value is ignored (only the LightningModule hook may
        # return -1 to skip the rest of an epoch), so request the stop explicitly.
        if batch_idx >= self.max_batches:
            trainer.should_stop = True

In [None]:
# Show the saved checkpoints. Yields an empty list, instead of raising
# FileNotFoundError, when no training run has written any yet (e.g. on a
# fresh kernel during Restart & Run All).
checkpoint_dir = "/kaggle/working/lightning_logs/version_0/checkpoints"
sorted(os.listdir(checkpoint_dir)) if os.path.isdir(checkpoint_dir) else []

In [None]:
# Build the custom stopping callback and wrap the CNN for Lightning.
early_stopping = MyEarlyStopping(mode="min", monitor="val_loss")
lt_model = Litmodel(model=InkDetection())
# Lightning 2.x trainer used for this run. Resuming goes through
# trainer.fit(..., ckpt_path="path/to.ckpt"); Trainer no longer accepts
# resume_from_checkpoint.
# trainer = lt.Trainer(strategy="auto", accelerator="gpu", devices=1, precision="16-mixed",
#                      limit_train_batches=10000, max_epochs=5, accumulate_grad_batches=10)

In [None]:
# trainer.fit(model=lt_model, train_dataloaders=train_loader)

In [None]:
# Load the trained weights into a bare InkDetection, outside Lightning.
CHECKPOINT_PATH = "/kaggle/working/lightning_logs/version_0/checkpoints/epoch=4-step=5000.ckpt"
# map_location="cpu" loads the tensors regardless of the device they were saved
# from; the evaluation cell moves the model to DEVICE afterwards.
# weights_only=False: Lightning checkpoints also pickle the hyperparameters
# captured by save_hyperparameters() (here including the model object), which
# torch>=2.6's weights_only=True default rejects. Unpickling can execute
# arbitrary code, so only do this for checkpoints you produced yourself.
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=False)
# Litmodel stores the network as `self.model`, so its keys carry a "model."
# prefix. Strip it only where present rather than blindly dropping 6 characters.
prefix = "model."
checkpoint["state_dict"] = {
    (key[len(prefix):] if key.startswith(prefix) else key): value
    for key, value in checkpoint["state_dict"].items()
}
trained_model = InkDetection()
trained_model.load_state_dict(checkpoint["state_dict"])

In [None]:
# Evaluate the trained model on the held-out validation rectangle.
# The score tracked here is F0.5 (beta = 0.5 weights precision above recall);
# the *_f1 names are historical. Per-batch metrics are shown every 100 batches,
# and the F0.5 over the whole validation set is displayed at the end.
progress_bar_val = notebook.tqdm(val_loader, desc="Batch Number", leave=True)
trained_model.eval()
trained_model.to(DEVICE)
epsilon = 1e-5  # keeps every ratio finite when a denominator is zero
f1_running = 0
tp_running = 0
tn_running = 0
fp_running = 0
fn_running = 0
with torch.no_grad():
    for k, (X, y) in enumerate(progress_bar_val):
        output = trained_model(X.to(DEVICE))
        # Predict class 1 (ink) when its logit is the larger one; keep shape
        # (B, 1) so it lines up element-wise with y.
        predicted = torch.argmax(output, dim=1).unsqueeze(1)
        y = y.to(DEVICE)

        # Confusion-matrix counts for this batch (labels and predictions are 0/1).
        true_positives = torch.sum(predicted * y).float()
        true_negatives = torch.sum((1 - predicted) * (1 - y)).float()
        false_positives = torch.sum(predicted * (1 - y)).float()
        false_negatives = torch.sum((1 - predicted) * y).float()

        total = true_positives + true_negatives + false_positives + false_negatives
        accuracy = (true_positives + true_negatives) / (total + epsilon)
        precision = true_positives / (true_positives + false_positives + epsilon)
        recall = true_positives / (true_positives + false_negatives + epsilon)
        # F-beta with beta = 0.5: (1 + b^2) * P * R / (b^2 * P + R).
        batch_f05 = (1.25 * precision * recall / (0.25 * precision + recall + epsilon)).item()

        tp_running += true_positives
        tn_running += true_negatives
        fp_running += false_positives
        fn_running += false_negatives
        # The same F0.5 from accumulated counts: (1+b^2)TP / ((1+b^2)TP + b^2 FN + FP).
        f1_running = 1.25 * tp_running / (1.25 * tp_running + 0.25 * fn_running + fp_running + epsilon)

        if k % 100 == 0:
            progress_bar_val.set_postfix({
                "Batch F0.5 ": batch_f05,
                "Accuracy ": accuracy.item(),
                "Total Positives in Batch ": y.sum().item(),
                "True Positives ": true_positives.item(),
                "Predicted_sum": predicted.sum().item(),
            })
            print(f1_running.item())

# F0.5 over the entire validation set.
float(f1_running)



In [None]:
# Bundle the Lightning logs (including checkpoints) into simonsZip.zip in the
# current working directory so they can be downloaded; shows the archive path.
import shutil
archive_path = shutil.make_archive(base_name="simonsZip", format="zip", root_dir="/kaggle/working/lightning_logs")
archive_path