In [1]:
from common import *

import os
from pathlib import Path
from typing import Optional
import numpy as np
import torch
import lightning as L
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
from torch.utils.data import DataLoader
from torchmetrics import JaccardIndex

from minerva.data.datasets.supervised_dataset import SupervisedReconstructionDataset
from minerva.data.readers.png_reader import PNGReader
from minerva.data.readers.tiff_reader import TiffReader
from minerva.models.loaders import FromPretrained
from minerva.models.nets.image.vit import SFM_BasePatch16_Downstream
from minerva.pipelines.lightning_pipeline import SimpleLightningPipeline
from minerva.transforms.transform import _Transform, TransformPipeline
from lightning.pytorch.loggers.csv_logs import CSVLogger

  from .autonotebook import tqdm as notebook_tqdm


## Data Module

In [2]:
# Locations of the input data and of the matching annotations.
root_data_dir = "/workspaces/HIAAC-KR-Dev-Container/shared_data/seam_ai_datasets/seam_ai/images"
root_annotation_dir = "/workspaces/HIAAC-KR-Dev-Container/shared_data/seam_ai_datasets/seam_ai/annotations"

# Both pipelines use an identically configured 512x512 PadCrop (same mode, seed
# and fill value).
# NOTE(review): presumably the shared seed keeps the two crops aligned -- confirm
# against the PadCrop implementation.
pad_crop_kwargs = dict(padding_mode="reflect", seed=42, constant_values=0)

# First pipeline: keep channel 0, pad/crop, cast to float32.
# NOTE(review): presumably applied to the images -- verify against GenericDataModule.
image_transform = TransformPipeline(
    [
        SelectChannel(0),
        PadCrop(512, 512, **pad_crop_kwargs),
        CastTo(np.float32),
    ]
)

# Second pipeline: pad/crop, cast to int64.
# NOTE(review): presumably applied to the annotations -- verify against GenericDataModule.
label_transform = TransformPipeline(
    [
        PadCrop(512, 512, **pad_crop_kwargs),
        CastTo(np.int64),
    ]
)

data_module = GenericDataModule(
    root_data_dir=root_data_dir,
    root_annotation_dir=root_annotation_dir,
    transforms=[image_transform, label_transform],
    batch_size=1,
    num_workers=1,
)

# Show the data module summary as the cell output.
data_module

DataModule
    Data: /workspaces/HIAAC-KR-Dev-Container/shared_data/seam_ai_datasets/seam_ai/images
    Annotations: /workspaces/HIAAC-KR-Dev-Container/shared_data/seam_ai_datasets/seam_ai/annotations
    Batch size: 1

In [3]:
# Prepare the "fit" stage, then pull a single training batch as a sanity check.
data_module.setup("fit")
train_loader = data_module.train_dataloader()
batch = next(iter(train_loader))

# Print the shapes of the first two elements of the batch.
print(batch[0].shape, batch[1].shape)

torch.Size([1, 1, 512, 512]) torch.Size([1, 1, 512, 512])


## Model

In [4]:
# Downstream network configured for 512x512 single-channel inputs and 6 classes.
model_config = {
    "img_size": (512, 512),
    "num_classes": 6,
    "in_chans": 1,
}
model = SFM_BasePatch16_Downstream(**model_config)

# Show the module tree as the cell output.
model

SFM_BasePatch16_Downstream(
  (backbone): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(1, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(ap

In [5]:
# Checkpoint file holding the pretrained weights.
pretrained_ckpt_path = "/workspaces/HIAAC-KR-Dev-Container/shared_data/seismic_foundation_model/pretrained_models/SFM-Base-512.pth"

# Patterns selecting which checkpoint entries are used.
# NOTE(review): these look like regexes; if so, "^blocks*" means "block" followed
# by zero or more "s", not a glob -- confirm this is the intended pattern.
backbone_key_patterns = [
    "^blocks*",
    "^cls_token",
    "^pos_embed",
    "^patch_embed",
    "^norm",
]

# Every selected key is prefixed with "backbone.model." before loading.
# NOTE(review): ckpt_load_weights_only=False presumably permits full unpickling
# of the checkpoint -- only use it with files from a trusted source.
model = FromPretrained(
    model=model,
    ckpt_path=pretrained_ckpt_path,
    ckpt_key="model",
    strict=False,
    filter_keys=backbone_key_patterns,
    keys_to_rename={"": "backbone.model."},
    ckpt_load_weights_only=False,
    error_on_missing_keys=False,
)


Performing key renaming with: {'': 'backbone.model.'}
	Renaming key: cls_token -> backbone.model.cls_token (changed: True)
	Renaming key: pos_embed -> backbone.model.pos_embed (changed: True)
	Renaming key: patch_embed.proj.weight -> backbone.model.patch_embed.proj.weight (changed: True)
	Renaming key: patch_embed.proj.bias -> backbone.model.patch_embed.proj.bias (changed: True)
	Renaming key: blocks.0.norm1.weight -> backbone.model.blocks.0.norm1.weight (changed: True)
	Renaming key: blocks.0.norm1.bias -> backbone.model.blocks.0.norm1.bias (changed: True)
	Renaming key: blocks.0.attn.qkv.weight -> backbone.model.blocks.0.attn.qkv.weight (changed: True)
	Renaming key: blocks.0.attn.qkv.bias -> backbone.model.blocks.0.attn.qkv.bias (changed: True)
	Renaming key: blocks.0.attn.proj.weight -> backbone.model.blocks.0.attn.proj.weight (changed: True)
	Renaming key: blocks.0.attn.proj.bias -> backbone.model.blocks.0.attn.proj.bias (changed: True)
	Renaming key: blocks.0.norm2.weight -> back

## Trainer

In [6]:
# Root directory for the CSV logs, checkpoints and pipeline run files.
log_dir = "./logs"

# NOTE(review): the run is named "dinov2" (and the pipeline directory below
# "f3_segmentation") although this notebook builds an SFM ViT on the
# seam_ai/parihaka data -- the names look copied from another example. They are
# left unchanged here because other cells may read from these paths.
logger = CSVLogger(log_dir, name="dinov2", version="parihaka")

# NOTE(review): no `monitor` is given, so per the Lightning documentation
# save_top_k=1 keeps the most recent checkpoint rather than the best-scoring
# one -- confirm that is intended (e.g. monitor="val_loss" otherwise).
checkpoint = ModelCheckpoint(
    save_top_k=1,
    save_last=True,
)

# Number of batches per epoch for this short demonstration run.
batches_per_epoch = 10

trainer = L.Trainer(
    max_epochs=2,
    limit_train_batches=batches_per_epoch,
    limit_val_batches=batches_per_epoch,
    # The default logging interval (50 steps) is larger than the number of
    # training batches, so step-level training metrics would never be logged
    # and Lightning emits a warning. Tie the interval to the batch limit.
    log_every_n_steps=batches_per_epoch,
    accelerator="gpu",
    devices=1,
    logger=logger,
    callbacks=[checkpoint],
)

# Wrap model and trainer in the pipeline; run information is saved under log_dir.
pipeline = SimpleLightningPipeline(
    model=model,
    trainer=trainer,
    log_dir=log_dir + "/f3_segmentation",
    save_run_status=True
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/lightning/fabric/utilities/seed.py:42: No seed found, seed set to 0
Seed set to 0


Log directory set to: /workspaces/HIAAC-KR-Dev-Container/Minerva-Dev/docs/notebooks/examples/seismic/facies_classification/parihaka/logs/f3_segmentation


In [None]:
# Run the "fit" task of the pipeline on the data module (starts training).
pipeline.run(data_module, task="fit")

/usr/local/lib/python3.10/dist-packages/lightning/fabric/loggers/csv_logs.py:268: Experiment logs directory ./logs/dinov2/parihaka exists and is not empty. Previous log files in this directory will be deleted when the new ones are saved!
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory ./logs/dinov2/parihaka/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Pipeline info saved at: /workspaces/HIAAC-KR-Dev-Container/Minerva-Dev/docs/notebooks/examples/seismic/facies_classification/parihaka/logs/f3_segmentation/run_2024-12-05-00-13-2445e0ec351c4a4b709aebf4d6065b4054.yaml



  | Name     | Type              | Params | Mode 
-------------------------------------------------------
0 | backbone | VisionTransformer | 90.2 M | train
1 | fc       | Identity          | 0      | train
2 | loss_fn  | CrossEntropyLoss  | 0      | train
-------------------------------------------------------
90.2 M    Trainable params
0         Non-trainable params
90.2 M    Total params
360.821   Total estimated model params size (MB)
301       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


                                                                           

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (10) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 1: 100%|██████████| 10/10 [00:02<00:00,  3.90it/s, v_num=haka, val_loss=2.160, train_loss=1.610]