In [42]:
import os

# Package(s) listed for import together with this config.
custom_imports = {"imports": ["geospatial_fm"]}

# --- Base runtime options ---
dist_params = {"backend": "nccl"}
log_level = "INFO"
load_from = resume_from = None
cudnn_benchmark = True

# --- Dataset ---
# Custom dataset type for similarity learning (provide paths to pairs of similar images).
dataset_type = "SiameseGeospatialDataset"
# TO BE DEFINED BY USER: data directory
data_root = "C:/Users/safae/Downloads/NASAdata/NASAdata"
samples_per_gpu = num_workers = 2

# --- Band selection and tile geometry ---
bands = list(range(6))  # band indices 0..5
tile_size = 224
orig_nsize = 512
crop_size = (tile_size,) * 2  # square crop: (tile_size, tile_size)

# --- File-name suffixes for images and masks ---
img_suffix = "_merged.tif"
seg_map_suffix = ".mask.tif"

# --- Label / no-data handling ---
ignore_index = -1
image_nodata = -9999
image_nodata_replace = 0
image_to_float32 = True
CLASSES = ("similar", "dissimilar")

In [43]:
# Transformer backbone hyper-parameters (read by the model cell).
patch_size = 16
embed_dim = 768
num_frames = 2
num_heads = 8
tubelet_size = 1
# The model cell reads `output_embed_dim` for the neck output width and the
# heads' input channels, but no visible cell defined it, so a fresh
# "Restart & Run All" failed with NameError. The neck is fed
# embed_dim * num_frames features, so the same width is used here.
output_embed_dim = embed_dim * num_frames

# Training schedule.
max_epochs = 80
eval_epoch_interval = 5

# Loss weights (13 values). NOTE(review): not referenced by any visible cell,
# and CLASSES holds only two entries -- confirm whether this is still needed.
loss_weights_multi = [
    0.386375, 0.661126, 0.548184, 0.640482, 0.876862, 0.925186, 3.249462,
    1.542289, 2.175141, 2.272419, 3.062762, 3.626097, 1.198702
]

In [44]:
# Training pipeline: load image + annotation, select bands, randomly flip,
# convert to tensors, crop to one tile, then reshape for the temporal model.
# NOTE(review): the final image Reshape needs len(bands) * num_frames channels;
# confirm that BandsExtract leaves that many channels in place.
train_pipeline = [
    dict(
        type="LoadGeospatialImageFromFile",
        to_float32=image_to_float32,
        channels_last=True,
    ),
    dict(type="LoadGeospatialAnnotations", reduce_zero_label=False),
    dict(type="BandsExtract", bands=bands),
    dict(type="RandomFlip", prob=0.5),
    dict(type="ToTensor", keys=["img", "gt_semantic_seg"]),
    # to channels first
    dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
    dict(type="TorchRandomCrop", crop_size=(tile_size, tile_size)),
    dict(
        type="Reshape",
        keys=["img"],
        new_shape=(len(bands), num_frames, tile_size, tile_size),
    ),
    dict(type="Reshape", keys=["gt_semantic_seg"], new_shape=(1, tile_size, tile_size)),
    dict(type="CastTensor", keys=["gt_semantic_seg"], new_type="torch.LongTensor"),
    dict(type="Collect", keys=["img", "gt_semantic_seg"]),
]

# Test pipeline: same loading and band selection, with no annotation,
# flip or crop steps.
test_pipeline = [
    dict(
        type="LoadGeospatialImageFromFile",
        to_float32=image_to_float32,
        channels_last=True,
    ),
    dict(type="BandsExtract", bands=bands),
    dict(type="ToTensor", keys=["img"]),
    # to channels first
    dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
    dict(
        type="Reshape",
        keys=["img"],
        new_shape=(len(bands), num_frames, -1, -1),
        # presumably tells Reshape which input dims fill the -1 placeholders
        # at output positions 2 and 3 -- TODO confirm against the transform
        look_up={"2": 1, "3": 2},
    ),
    dict(type="CastTensor", keys=["img"], new_type="torch.FloatTensor"),
    dict(
        type="CollectTestList",
        keys=["img"],
        meta_keys=[
            "img_info", "seg_fields", "img_prefix", "seg_prefix",
            "filename", "ori_filename", "img", "img_shape",
            "ori_shape", "pad_shape", "scale_factor",
        ],
    ),
]


In [45]:
# Data-loading configuration. Only the training split is configured here.
data = dict(
    samples_per_gpu=samples_per_gpu,
    workers_per_gpu=num_workers,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        # Fixed: the separator between "Downloads" and "similarimages" was
        # missing ("Downloadssimilarimages").
        # NOTE(review): confirm this is the intended location of the pairs file.
        similar_pairs_file="C:/Users/safae/Downloads/similarimages",
        img_suffix=img_suffix,
        pipeline=train_pipeline,
    ),
    # Validation and test datasets can also be provided with similar/dissimilar pairs
    # ...
)

In [46]:
# Directories for each dataset split.
# Forward slashes are used throughout: in a normal (non-raw) string literal
# "C:\Users" is a SyntaxError because "\U" starts a unicode escape.
# NOTE(review): `test` points at the same directory as `val` -- confirm intended.
splits = dict(
    train="C:/Users/safae/Downloads/TrainData",
    val="C:/Users/safae/Downloads/ValData",
    test="C:/Users/safae/Downloads/ValData",
)

In [33]:
# Loss configuration for similarity learning.
# `type` may be "ContrastiveLoss" or "TripletLoss".
# NOTE(review): a later cell assigns `loss_func` again (DiceLoss), and the
# model cell reads whichever assignment ran last.
loss_func = {
    "type": "ContrastiveLoss",
    "margin": 0.5,      # margin parameter for the contrastive loss
    "pos_weight": 1.0,  # weight for the positive-pairs loss
    "neg_weight": 1.0,  # weight for the negative-pairs loss
}
print(loss_func)

{'type': 'ContrastiveLoss', 'margin': 0.5, 'pos_weight': 1.0, 'neg_weight': 1.0}


In [37]:
# Optimizer: Adam, with gradient clipping disabled.
optimizer = dict(type="Adam", lr=1.3e-05, betas=(0.9, 0.999))
optimizer_config = dict(grad_clip=None)

# Learning-rate schedule: "poly" policy with linear warmup, configured per
# iteration (by_epoch=False).
lr_config = dict(
    policy="poly",
    warmup="linear", warmup_iters=1500, warmup_ratio=1e-06,
    power=1.0, min_lr=0.0,
    by_epoch=False,
)

# Logging: text and TensorBoard hooks, interval=20, both with by_epoch=False.
log_config = dict(
    interval=20,
    hooks=[
        dict(type=hook_type, by_epoch=False)
        for hook_type in ("TextLoggerHook", "TensorboardLoggerHook")
    ],
)

# Checkpointing: interval=10 with by_epoch=True.
checkpoint_config = dict(by_epoch=True, interval=10)

# Evaluation: mIoU metric, also used as the save_best key.
evaluation = dict(metric="mIoU", pre_eval=True, save_best="mIoU", by_epoch=False)

# NOTE(review): this rebinds `loss_func`, replacing the ContrastiveLoss config
# assigned in an earlier cell; the model cell reads this DiceLoss config.
loss_func = dict(type="DiceLoss", use_sigmoid=False, loss_weight=1, ignore_index=-1)

workflow = [("train", 1)]
norm_cfg = dict(type="BN", requires_grad=True)


In [41]:
# Model definition: temporal ViT encoder -> token-to-embedding neck -> FCN heads.
# Requires `output_embed_dim` (neck output width / head input channels) to be
# defined by an earlier cell, along with patch_size, num_frames, tubelet_size,
# bands, embed_dim, num_heads, tile_size, CLASSES and loss_func.
model = dict(
    type="TemporalEncoderDecoder",
    frozen_backbone=False,
    backbone=dict(
        type="TemporalViTEncoder",
        patch_size=patch_size,
        num_frames=num_frames,
        tubelet_size=tubelet_size,
        in_chans=len(bands),
        embed_dim=embed_dim,
        depth=12,
        num_heads=num_heads,
        mlp_ratio=4.0,
        norm_pix_loss=False
    ),
    neck=dict(
        type="ConvTransformerTokensToEmbeddingNeck",
        embed_dim=embed_dim*num_frames,
        output_embed_dim=output_embed_dim,
        drop_cls_token=True,
        # Derived from tile_size and patch_size instead of the hard-coded 14
        # (224 // 16), so the values cannot drift if either setting changes.
        Hp=tile_size // patch_size,
        Wp=tile_size // patch_size
    ),
    decode_head=dict(
        num_classes=len(CLASSES),
        in_channels=output_embed_dim,
        type="FCNHead",
        in_index=-1,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        norm_cfg=dict(type="BN", requires_grad=True),
        align_corners=False,
        loss_decode=loss_func
    ),
    # Same settings as decode_head except num_convs=2.
    auxiliary_head=dict(
        num_classes=len(CLASSES),
        in_channels=output_embed_dim,
        type="FCNHead",
        in_index=-1,
        channels=256,
        num_convs=2,
        concat_input=False,
        dropout_ratio=0.1,
        norm_cfg=dict(type="BN", requires_grad=True),
        align_corners=False,
        loss_decode=loss_func
    ),
    train_cfg=dict(),
    # "slide" test mode: crop of one tile, stride of half a tile.
    test_cfg=dict(
        mode="slide",
        stride=(tile_size // 2, tile_size // 2),
        crop_size=(tile_size, tile_size),
    ),
)
gpu_ids = range(0, 1)
auto_resume = False
