# The SWIN UNETR model from MONAI
The SWIN UNETR model was adapted from MONAI and adjusted to 2D version of the dataset

## Load libraries

In [1]:
import os
import shutil
import tempfile
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

from monai.transforms import (
    AsDiscrete,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
    EnsureTyped,
)
from monai.data import (
    ThreadDataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
    set_track_meta,
)
from monai.inferers import sliding_window_inference
from monai.networks.nets import SwinUNETR, AttentionUnet
from monai.metrics import DiceMetric
from monai.losses import DiceCELoss
from monai.losses import DiceLoss, TverskyLoss, FocalLoss

import torch
import einops
import warnings




warnings.filterwarnings("ignore")
import torch
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.is_available() 
torch.cuda.device_count()
torch.cuda.current_device()
torch.cuda.get_device_name(0)

IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html


'NVIDIA RTX A4000'

### Check your location

In [None]:

# Check the amount of shared memory
os.system('df -h /dev/shm')

import os
dir = os.getcwd()
print('Current directory is:',dir)


warnings.filterwarnings("ignore")

## Set the model parametres and predict

The code allows you to pick the image

In [3]:
import json
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import os
import rasterio
from rasterio.windows import Window
import tifffile as tiff


def get_filename_without_extension():
    data_to_detect = input("Enter the path to the image: ").strip('"')
    filename_to_detect = os.path.splitext(os.path.basename(data_to_detect))[0]
    return filename_to_detect
filename_to_detect = get_filename_without_extension()

class NumpyDataset(Dataset):
    def __init__(self, image_paths, label_data):
        self.image_paths = image_paths
        self.label_data = label_data  # Accept either paths or preloaded arrays

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = np.load(self.image_paths[idx])

        # Check if label_data contains paths or arrays
        if isinstance(self.label_data[idx], str):
            label = np.load(self.label_data[idx])
        else:
            label = self.label_data[idx]

        image = self.replace_nans_in_array(image)
        label = self.replace_nans_in_array(label)

        image = torch.tensor(image, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.float32)

        return image, label, self.image_paths[idx]

    @staticmethod
    def replace_nans_in_array(arr):
        arr[np.isnan(arr)] = 0
        arr[np.isinf(arr)] = 0
        return arr


def load_dataset_json(json_path):
    with open(json_path, 'r') as file:
        dataset_json = json.load(file)
    return dataset_json

def prepare_test_loader(test_image_paths, batch_size):
    dummy_label_dir = "./model/temp_labels/"
    if not os.path.exists(dummy_label_dir):
        os.makedirs(dummy_label_dir)

    test_labels = []
    for image_path in test_image_paths:
        dummy_label_path = os.path.join(dummy_label_dir, os.path.basename(image_path).replace('.npy', '_label.npy'))
        np.save(dummy_label_path, np.zeros_like(np.load(image_path)))
        test_labels.append(dummy_label_path)

    test_dataset = NumpyDataset(test_image_paths, test_labels)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    return test_loader

def initialize_model():
    """Initialize the SwinUNETR model."""
    model = SwinUNETR(
        img_size=(96, 96),
        in_channels=48,
        out_channels=1,  # Use the passed `num_classes`
        use_checkpoint=True,
        feature_size=48,
        depths=(3, 9, 18, 3),
        num_heads=(4, 8, 16, 32),
        drop_rate=0.1,  # Added dropout
        attn_drop_rate=0.1,
        dropout_path_rate=0.2,
        spatial_dims=2
    )
    return model


def load_adapted_model(model, checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint, strict=False)
    return model

### Load model

def load_large_image(image_path):
    print(f"Loading large image from {image_path}")
    with rasterio.open(image_path) as src:
        image = src.read()
        transform = src.transform  # Capture the affine transformation matrix
    return image, transform

def split_to_tiles(image, tile_size, save_dir):
    print(f"Splitting image into tiles of size {tile_size}x{tile_size}")
    tiles = []
    num_channels, height, width = image.shape
    for i in range(0, height, tile_size):
        for j in range(0, width, tile_size):
            window = Window(j, i, tile_size, tile_size)
            tile = image[:, i:i+tile_size, j:j+tile_size]
            if tile.shape[1] == tile_size and tile.shape[2] == tile_size:
                tile_path = os.path.join(save_dir, f'tile_{i}_{j}.tif')
                print(f"Saving tile to {tile_path}")
                tiff.imwrite(tile_path, tile)
                tiles.append(tile_path)
    print(f"Total number of tiles: {len(tiles)}")
    return tiles

def convert_tiles_to_npy(tile_paths, npy_dir):
    if not os.path.exists(npy_dir):
        os.makedirs(npy_dir)
    npy_paths = []
    for tile_path in tile_paths:
        image = tiff.imread(tile_path)
        npy_path = os.path.join(npy_dir, os.path.basename(tile_path).replace('.tif', '.npy'))
        print(f"Converting {tile_path} to {npy_path}")
        np.save(npy_path, image)
        npy_paths.append(npy_path)
    return npy_paths

def merge_tiles(tiles, image_shape, tile_size):
    print(f"Merging tiles back into full image of shape {image_shape}")
    _, height, width = image_shape  # Assume image_shape is in the format (num_channels, height, width)
    full_image = np.zeros((height, width), dtype=np.uint8)  # Only one channel for the full image

    for tile_path in tiles:
        tile = tiff.imread(tile_path)
        # Adjust the parsing to match the filename pattern used in split_to_tiles
        filename = os.path.basename(tile_path)
        parts = filename.replace('tile_', '').replace('.tif', '').split('_')
        if len(parts) == 2:  # Ensure the filename format is correct
            i, j = map(int, parts)
            full_image[i:i+tile_size, j:j+tile_size] = tile  # Single channel tile assignment
        else:
            print(f"Unexpected filename format: {filename}")

    return full_image
### 

def predict_and_save(model, test_loader, device, tile_paths, save_dir='model/temp_pred'):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    print(f"Starting prediction on device: {device}")
    model.eval()
    model.to(device)
    predictions = []

    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            images, _, _ = batch  # Unpack the batch; labels and paths are not needed for predictions
            print(f"Predicting batch {batch_idx + 1}/{len(test_loader)}")
            print(f"Images type: {type(images)}, Images shape: {images.shape if isinstance(images, torch.Tensor) else 'unknown'}")

            images = images.to(device)  # Move images to the correct device
            outputs = model(images)
            outputs = torch.sigmoid(outputs)
            outputs = outputs.cpu().numpy()

            # Save the predictions for each image in the batch
            for i in range(images.shape[0]):
                output_image = outputs[i]
                binary_output = (output_image > 0.5).astype(np.uint8)  # Apply threshold to create binary output
                original_tile_name = os.path.basename(tile_paths[batch_idx * test_loader.batch_size + i]).replace('.npy', '.tif')
                save_path = os.path.join(save_dir, original_tile_name)
                print(f"Saving prediction to {save_path}")
                tiff.imwrite(save_path, binary_output)
                predictions.append(save_path)

    return predictions



def save_full_raster(predictions, image_shape, tile_size, save_path):
    print(f"Saving full raster to {save_path}")
    full_raster = merge_tiles(predictions, image_shape, tile_size)
    with rasterio.open(save_path, 'w', driver='GTiff', height=full_raster.shape[0],
                       width=full_raster.shape[1], count=1, dtype=full_raster.dtype) as dst:
        dst.write(full_raster, 1)

def create_tfw_file(transform, tfw_path):
    print(f"Creating .tfw file at {tfw_path}")
    with open(tfw_path, 'w') as f:
        f.write(f"{transform.a}\n")  # pixel size in the x-direction
        f.write(f"{transform.b}\n")  # rotation term (always 0 for north-up images)
        f.write(f"{transform.d}\n")  # rotation term (always 0 for north-up images)
        f.write(f"{transform.e}\n")  # pixel size in the y-direction (usually negative)
        f.write(f"{transform.c}\n")  # x-coordinate of the upper-left corner of the upper-left pixel
        f.write(f"{transform.f}\n")  # y-coordinate of the upper-left corner of the upper-left pixel

if __name__ == "__main__":
    image_path = f'data/{filename_to_detect}.tif'
    tile_size = 96
    batch_size = 12
    model = initialize_model()

    print("Loading and preprocessing the large image")
    large_image, transform = load_large_image(image_path)
    large_image = NumpyDataset.replace_nans_in_array(large_image)
    image_shape = large_image.shape
    print(f"Large image shape: {image_shape}")

    temp_dir = './model/temp/'
    npy_dir = './model/temp_npy/'
    pred_dir = './model/temp_pred/'
    result_path = f'./model/Swin_UNETR/Aug/{filename_to_detect}_swinunetr.tif'
    tfw_path = f'./model/Swin_UNETR/Aug/{filename_to_detect}_swinunetr.tfw'

    if not os.path.exists(temp_dir):
        os.makedirs(temp_dir)

    tile_paths = split_to_tiles(large_image, tile_size, temp_dir)
    npy_paths = convert_tiles_to_npy(tile_paths, npy_dir)
    test_loader = prepare_test_loader(npy_paths, batch_size)

    adapted_model_path = 'model/Swin_UNETR/Aug/bin_best_model_lr_B.pth'
    model = load_adapted_model(model, adapted_model_path)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    predictions = predict_and_save(model, test_loader, device, tile_paths, save_dir=pred_dir)

    # Adjust the image_shape to reflect the single-channel output
    single_channel_image_shape = (1, image_shape[1], image_shape[2])
    save_full_raster(predictions, single_channel_image_shape, tile_size, result_path)

    # Create the .tfw file using the transform from the original large image
    create_tfw_file(transform, tfw_path)

    print("Prediction, saving, and .tfw file creation complete")



Enter the path to the image:  Swin_UNETR_N48XL/data/M1_selected_L5_S2_S1_MSRM_PCA_N48-0000000000-0000000000.tif


Loading and preprocessing the large image
Loading large image from data/M1_selected_L5_S2_S1_MSRM_PCA_N48-0000000000-0000000000.tif
Large image shape: (48, 4864, 3456)
Splitting image into tiles of size 96x96
Saving tile to ./model/temp/tile_0_0.tif
Saving tile to ./model/temp/tile_0_96.tif
Saving tile to ./model/temp/tile_0_192.tif
Saving tile to ./model/temp/tile_0_288.tif
Saving tile to ./model/temp/tile_0_384.tif
Saving tile to ./model/temp/tile_0_480.tif
Saving tile to ./model/temp/tile_0_576.tif
Saving tile to ./model/temp/tile_0_672.tif
Saving tile to ./model/temp/tile_0_768.tif
Saving tile to ./model/temp/tile_0_864.tif
Saving tile to ./model/temp/tile_0_960.tif
Saving tile to ./model/temp/tile_0_1056.tif
Saving tile to ./model/temp/tile_0_1152.tif
Saving tile to ./model/temp/tile_0_1248.tif
Saving tile to ./model/temp/tile_0_1344.tif
Saving tile to ./model/temp/tile_0_1440.tif
Saving tile to ./model/temp/tile_0_1536.tif
Saving tile to ./model/temp/tile_0_1632.tif
Saving tile t