In [1]:
import os
import warnings
from functools import partial
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
from minerva.data.readers.png_reader import PNGReader
from minerva.data.readers.tiff_reader import TiffReader

from common import get_data_module
from common import get_evaluation_pipeline

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# --- Data ---------------------------------------------------------------------
# Dataset root. Set the SEAM_AI_ROOT environment variable to point at another
# copy instead of editing this absolute dev-container path.
data_root = Path(
    os.environ.get(
        "SEAM_AI_ROOT",
        "/workspaces/HIAAC-KR-Dev-Container/shared_data/seam_ai_datasets/seam_ai",
    )
)
root_data_dir = str(data_root / "images")             # Input images
root_annotation_dir = str(data_root / "annotations")  # Label annotations

# Model input size (H, W). The saved predictions have this spatial size while the
# raw predict samples are smaller (1006x590), so this appears to be a padded
# size rather than the native image size — TODO confirm with get_data_module.
img_size = (1008, 784)
model_name = "dinov2_dpt"       # Model name (identifier only)
dataset_name = "seam_ai"        # Dataset name (identifier only)
single_channel = False          # If True, use single-channel images instead of 3 channels

# --- Run ----------------------------------------------------------------------
log_dir = Path("./logs")        # Directory holding training logs / saved predictions
batch_size = 1                  # Batch size
seed = 42                       # Seed for reproducibility
num_epochs = 100                # Number of epochs to train
is_debug = False                # If True, only 3 batches are processed for 3 epochs
accelerator = "gpu"             # "cpu" or "gpu"
devices = 1                     # Number of devices (GPUs)

In [3]:
# Build the data module from the configuration values; its repr (shown below)
# echoes the data/annotation roots and the batch size.
data_module = get_data_module(
    seed=seed,
    batch_size=batch_size,
    img_size=img_size,
    single_channel=single_channel,
    root_annotation_dir=root_annotation_dir,
    root_data_dir=root_data_dir,
)

data_module

DataModule
    Data: /workspaces/HIAAC-KR-Dev-Container/shared_data/seam_ai_datasets/seam_ai/images
    Annotations: /workspaces/HIAAC-KR-Dev-Container/shared_data/seam_ai_datasets/seam_ai/annotations
    Batch size: 1

In [4]:
# Just to check if the data module is working
data_module.setup("predict")
train_batch_x, train_batch_y = next(iter(data_module.predict_dataloader()))
train_batch_x.shape, train_batch_y.shape

(torch.Size([1, 3, 1006, 590]), torch.Size([1, 1006, 590]))

In [5]:
# Load every saved `predictions.npy` under log_dir, keyed by model name.
# The loop variable is deliberately NOT `model_name`, so the config constant
# defined in the configuration cell is not clobbered.
models_predictions = {}
for pred_file in sorted(log_dir.rglob("predictions.npy")):
    # Assumes the layout <log_dir>/.../<model>/<a>/<b>/predictions.npy, i.e. the
    # model name is the directory two levels above the file's parent — TODO
    # confirm against the training pipeline's log layout.
    run_model_name = pred_file.parents[2].name
    if run_model_name in models_predictions:
        warnings.warn(
            f"Multiple prediction files map to {run_model_name!r}; "
            f"{pred_file} replaces the previously loaded one."
        )
    # Each array is ~3.8 GB at the shapes printed below, so memory-map it instead
    # of reading it all into RAM. Copy-on-write ("c") keeps the array writable in
    # memory without ever modifying the file on disk.
    run_predictions = np.load(pred_file, mmap_mode="c")
    models_predictions[run_model_name] = run_predictions
    print(f"{run_model_name:<20}: {run_predictions.shape} ({run_predictions.dtype})")

if not models_predictions:
    raise FileNotFoundError(f"No predictions.npy files found under {log_dir.resolve()}")

dinov2_dpt          : (200, 6, 1008, 784) (float32)
dinov2_mla          : (200, 6, 1008, 784) (float32)


In [None]:
img_no = 100  # index of the predict-split sample to visualize

# Fetch the sample from the predict dataset. (`X` / `y` were never defined in
# this notebook, so the previous version raised NameError on a fresh kernel.)
# Assumes the predict dataset is indexable, unshuffled, in the same order the
# saved predictions were written, and yields (image, label) pairs like its
# batches do — TODO confirm against get_data_module.
sample_image, sample_label = data_module.predict_dataloader().dataset[img_no]
sample_image = np.asarray(sample_image)   # expected (C, H, W), per the batch check above
sample_label = np.asarray(sample_label)   # expected (H, W)
label_h, label_w = sample_label.shape[-2:]

# imshow wants channels-last; also rescale to [0, 1], since the image may be
# normalized — NOTE(review): confirm the dataset's value range.
display_image = np.moveaxis(sample_image, 0, -1).astype(np.float32)
if display_image.shape[-1] == 1:
    display_image = display_image[..., 0]
lo, hi = display_image.min(), display_image.max()
display_image = (display_image - lo) / (hi - lo) if hi > lo else np.zeros_like(display_image)

for run_model_name, run_predictions in models_predictions.items():
    if not 0 <= img_no < len(run_predictions):
        raise IndexError(f"img_no={img_no} out of range for {run_model_name} ({len(run_predictions)} samples)")
    num_classes = run_predictions.shape[1]

    # Predictions are (N, num_classes, H, W) per-class scores (presumably logits or
    # probabilities); argmax over the class axis gives a label-like class map.
    pred_mask = np.asarray(run_predictions[img_no]).argmax(axis=0)
    # Predictions are img_size (1008x784) while labels are smaller (1006x590).
    # Assumes the pipeline pads bottom/right, so cropping realigns them — TODO
    # confirm; drop this crop if the pipeline resizes instead.
    pred_mask = pred_mask[:label_h, :label_w]

    fig, (ax_image, ax_label, ax_pred) = plt.subplots(1, 3, figsize=(15, 10))
    ax_image.imshow(display_image, cmap="gray" if display_image.ndim == 2 else None)
    # Shared vmin/vmax so the same class id gets the same color in label and prediction.
    ax_label.imshow(sample_label, vmin=0, vmax=num_classes - 1)
    ax_pred.imshow(pred_mask, vmin=0, vmax=num_classes - 1)
    for ax, title in zip(
        (ax_image, ax_label, ax_pred),
        ("Image", "Label", f"{run_model_name} prediction"),
    ):
        ax.set_title(title)
        ax.axis("off")

    plt.show()
    plt.close(fig)  # avoid accumulating open figures across loop iterations