In [None]:
import numpy as np
import os
import glob
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, DataLoader, Subset, TensorDataset
import pandas as pd
from sklearn.impute import SimpleImputer
import torch
from sklearn.model_selection import KFold
from monai.networks.nets import SwinUNETR, DynUNet, UNet
from monai.transforms import RandFlipd, RandRotate90d
from torch import nn
from transformers import get_cosine_schedule_with_warmup
from sklearn.metrics import accuracy_score
import torch.optim as optim
from tqdm import tqdm
import gc
import rasterio
from monai.losses import DiceLoss
from PIL import Image
import torchvision.transforms as transforms
from torchvision.transforms.v2 import GaussianNoise
import random
from monai.inferers import SlidingWindowInferer

In [None]:
# 2-D DynUNet: 3 input bands -> 3 output channels, seven resolution levels.
# Level 0 keeps full resolution (stride 1); each of the six deeper levels
# halves it (stride 2), and each of those gets a 2x transposed-conv upsample.
model = DynUNet(
    spatial_dims=2,
    in_channels=3,
    out_channels=3,
    kernel_size=(3,) * 7,
    strides=(1,) + (2,) * 6,
    upsample_kernel_size=(2,) * 6,
    filters=[64, 96, 128, 192, 256, 384, 512],
    dropout=0.2,
    deep_supervision=True,
    deep_supr_num=2,
)

In [None]:
def predict(data, model, model_name, batch_size=32, n_fold=0, device='cuda', delete=False,
            roi_size=(128, 128), overlap=0.5):
    """Load one cross-validation fold's weights and run sliding-window inference.

    Args:
        data: input tensor handed to the inferer. The notebook passes a
            (1, C, H, W) float tensor on the CPU (assumed; windows are moved to
            `device` one batch at a time).
        model: network whose weights are replaced by
            `<model_name><n_fold>/checkpoint.pth`.
        model_name: prefix of the per-fold checkpoint directory.
        batch_size: number of windows per forward pass (`sw_batch_size`).
        n_fold: fold index appended to `model_name`.
        device: device the model and each window batch run on.
        delete: unused; kept only for backward compatibility.
        roi_size: spatial size of each sliding window.
        overlap: fractional overlap between neighbouring windows.

    Returns:
        The stitched model output, accumulated on the CPU.
    """
    model = model.to(device)  # load the model onto the compute device
    checkpoint_path = os.path.join(model_name + str(n_fold), 'checkpoint.pth')
    # Map tensors straight onto the target device; weights_only restricts
    # unpickling to tensors/containers instead of arbitrary objects.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)

    model.eval()
    inferer = SlidingWindowInferer(
        roi_size=roi_size,
        sw_batch_size=batch_size,
        overlap=overlap,
        mode="gaussian",
        progress=True,
        sw_device=device,
        # stitched output lives on the CPU so the full chunk is not kept on the GPU
        device=torch.device('cpu'),
    )
    with torch.no_grad():
        outputs = inferer(data, model)
    return outputs

In [None]:
# Tiled inference over the whole raster. The image is cut into CHUNK x CHUNK
# tiles. Each tile is read with CHUNK_MARGIN pixels of extra context on every
# side (clamped to the raster), run through all folds, and the fold logits are
# averaged. The per-pixel argmax class is cropped back to the tile and stored
# as uint8.
CHUNK = 5000        # tile edge length in pixels
CHUNK_MARGIN = 64   # context around each tile so tile borders are not treated as image borders
N_FOLDS = 5
batch_size = 128

with rasterio.open('dataset/result.tif') as src:
    # Read only bands 1-3. `src.read()[:3]` would load every band and keep the
    # full array alive through the slice view.
    data = src.read([1, 2, 3])
print(data.shape)
gc.collect()

height, width = data.shape[1], data.shape[2]
preds_i = []
for start_i in range(0, height, CHUNK):
    end_i = min(start_i + CHUNK, height)
    ctx_i0 = max(start_i - CHUNK_MARGIN, 0)
    ctx_i1 = min(end_i + CHUNK_MARGIN, height)
    preds_j = []
    for start_j in range(0, width, CHUNK):
        end_j = min(start_j + CHUNK, width)
        ctx_j0 = max(start_j - CHUNK_MARGIN, 0)
        ctx_j1 = min(end_j + CHUNK_MARGIN, width)

        chunk_data = torch.from_numpy(
            np.ascontiguousarray(data[:, ctx_i0:ctx_i1, ctx_j0:ctx_j1])
        ).float() / 255.0
        chunk_data = chunk_data.unsqueeze(0)  # add batch axis -> (1, C, h, w)

        # Running sum of fold outputs instead of keeping all folds in memory.
        # (The original np.stack turned the tensors into an ndarray, which
        # torch.softmax then rejected.)
        logits_sum = None
        for n_fold in range(N_FOLDS):
            fold_logits = predict(chunk_data, model, 'DynUNet', batch_size, n_fold, 'cuda')
            if logits_sum is None:
                logits_sum = fold_logits
            else:
                logits_sum += fold_logits
        mean_logits = logits_sum / N_FOLDS
        # softmax is monotonic, so argmax of the mean logits equals argmax of
        # softmax(mean logits); skipping it saves a full-size float copy.
        labels = mean_logits.argmax(dim=1).to(torch.uint8).numpy()

        # Crop the context margin away so tiles cover the raster exactly once.
        row0 = start_i - ctx_i0
        col0 = start_j - ctx_j0
        labels = labels[:, row0:row0 + (end_i - start_i), col0:col0 + (end_j - start_j)]
        preds_j.append(labels)

        del chunk_data, fold_logits, logits_sum, mean_logits
        gc.collect()
    preds_j = np.concatenate(preds_j, axis=2)
    preds_i.append(preds_j)
    gc.collect()
preds = np.concatenate(preds_i, axis=1)  # (1, H, W) uint8 class map
del data
del preds_i
del preds_j
gc.collect()

In [None]:
# Zero the predictions wherever the last band of the source raster is 0.
# NOTE(review): that band is presumably an alpha / valid-data mask; confirm.
with rasterio.open('dataset/result.tif') as src:
    # Read only the last band, keeping a leading band axis -> (1, H, W).
    # `src.read()[-1:]` would load every band and keep them alive via the view.
    mask = src.read([src.count])
assert mask.shape == preds.shape, (mask.shape, preds.shape)
# In-place masked assignment avoids allocating a second full-size class map.
preds[mask == 0] = 0
del mask
gc.collect()

In [2]:
# Write the class map as a single-band GeoTIFF that reuses the source raster's
# driver, size, CRS and transform.
with rasterio.open('dataset/result.tif') as src:
    meta = src.meta.copy()
print(meta)
meta.update(
    count=1,
    # take the dtype from the data actually written, not from the source bands
    dtype=preds.dtype.name,
    # NOTE(review): 0 is both a predicted class and the fill for masked pixels,
    # so no nodata value is declared.
    nodata=None,
)
assert preds.shape == (1, meta['height'], meta['width']), preds.shape
with rasterio.open('pred.tif', "w", **meta) as dst:
    dst.write(preds)

{'driver': 'GTiff', 'dtype': 'uint8', 'nodata': 0.0, 'width': 37677, 'height': 35068, 'count': 4, 'crs': CRS.from_wkt('PROJCS["WGS 84 / UTM zone 51N",GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AUTHORITY["EPSG","4326"]],PROJECTION["Transverse_Mercator"],PARAMETER["latitude_of_origin",0],PARAMETER["central_meridian",123],PARAMETER["scale_factor",0.9996],PARAMETER["false_easting",500000],PARAMETER["false_northing",0],UNIT["metre",1,AUTHORITY["EPSG","9001"]],AXIS["Easting",EAST],AXIS["Northing",NORTH],AUTHORITY["EPSG","32651"]]'), 'transform': Affine(0.0626, 0.0, 264493.5783968877,
       0.0, -0.0626, 4512247.584260046)}
{'driver': 'GTiff', 'dtype': 'uint8', 'nodata': None, 'width': 37677, 'height': 35068, 'count': 1, 'crs': CRS.from_wkt('PROJCS["WGS 84 / UTM zone 51N",GEOGCS["WGS 84",DATUM["WGS_1984",SP