In [12]:
import argparse
import numpy as np
import rasterio
from rasterio.io import MemoryFile
from rasterio.merge import merge
import os
import torch
from tqdm import tqdm

In [26]:
def predict(model, image, tile=True):
    if tile:
        out = torch.empty(size=(512,512))
        for i in range(2):
            for j in range(2):
                sub = image[:,:, i*256:i*256+256, j*256:j*256+256]
                out_ = model(sub)
                out_ = torch.nn.Softmax(dim=1)(out_)
                out_ = torch.argmax(out_, dim=1).squeeze().long().cpu().numpy()
                out[i*256:i*256+256, j*256:j*256+256] = out_
    else:
        out = model(image)  # assuming both are at the same device
        out = torch.nn.Softmax(dim=1)(out)
        out = torch.argmax(out, dim=1).squeeze().long().cpu().numpy() # this gives image without batch dim as i9niger data type
    return out

In [27]:
def preprocess(img, resze=False, size=512, channel_first=True):
    if resze:
        assert size is not None, 'Resie dimension is required.'
    
    if channel_first:
        size = (4, size, size)
    else:
        size = (size, size, 4)
    img = no.resize(img, new_size=size)
    img = (img.min())/((img.max()-img.min())+ 1e-7)
    
    img = torch.from_numy(img) 
    
    if not channel_first:
        img = img.permute(2, 0, 1).unsqueeze(0)
        
    return img

In [28]:
def load_model(channel, n_class, weights):
    model = UNet(n_channels=channel, n_classes=n_class, bilinear=True)
    model = model.load_state_dict(torch.load(weights))
    return model

In [29]:
def predict_moaic(files, n_channel, n_class, checkpoint, out_dir, shape):
    '''
    files: list of file pathes for rasters
    n_channel: number of channels in the image
    n_class: number of coutput classes from the model
    chkpoint_dir: checkpoint or model weight full path
    out_dir: the directory to save predicted and mosaiced image
    '''
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = load_model(channel=n_channel, n_class=n_class, weights=checkpoint)
    model = model.to(device)

    mosaic_container = []
    for file in tqdm(files):
        ins = rasterio.open(file)
        profile = ins.profile
        profile.update(count=1, dtype=np.uint8)
        arr = ins.read() # assuming its 4 chnnel image
        c, h, w = arr.shape

        pr_array = preprocess(img = arr, resize=True, size = shape, channel_first=True)
        pr_array = pr_array.to(device)

        pr_array_resize = np.resize(pr_array, new_size=(h, w))

        with MemoryFile() as memfile:
            with memfile.open(**profile) as dataset:
                dataset.write(pr_array_resize, 1)

            with memfile.open() as o_dataset:
                mosaic_container.append(o_dataset)

    mosaic, out_trans = merge(mosaic_container)
    out_meta.update({"height": mosaic.shape[1],
                     "width": mosaic.shape[2],
                     "transform": out_trans,
                       }
                        )

    out_fp = f'{out_dir}/mosaic.tif'

    with rasterio.open(out_fp, "w", **out_meta) as dest:
        dest.write(mosaic)

In [30]:
def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-l', '--checkpoint', dest='checkpoint', type=str, default='/home/getch/ssl/GORILLA/results/esalandcover/checkpoints/ff_best_weight.pth', help='Load model from a .pth file')
    parser.add_argument('-n', '--n_class', default=8, type=int, help='Number of classes in the mask/label and or model')
    parser.add_argument('-c', '--n_channel', default=4, type=int, help='Number of channels in the image')
    parser.add_argument('-t', '--data_dir', type=str, default='/home/getch/.cache/kagglehub/datasets/getachewworkineh/uganda-landcover/versions/1/samples/test/images', help='train data folder')
    parser.add_argument('-d', '--save_dir', type=str, default='/home/getch/ssl/GORILLA/results/full_preduction', help='directory to save the checkpoint and results')
    parser.add_argument('-x', '--ext', type=str, help='image and labnel file extension', default='.tif')
    parser.add_argument('-s', '--shape', type=int, default=512, help='image shape')
    return parser.parse_args()

In [33]:
def run_main(args):
    
    files = glob(f'{args.data_dir}/{args.ext}')
    
    if not os.exists(args.save_dir):
        os.makedirs(args.save_dir, exist_ok=True)
    
    predict_moaic(files=files,
                  n_channel=args.n_channel,
                  n_class=args.n_class,
                  checkpoint=args.checkpoint,
                  out_dir=args.save_dir,
                  shape=args.shape)
    
# if __name__ == '__main__':
#     args = get_args()
#     main(args=args)

In [34]:
args = get_args()
run_main(args=args)

usage: ipykernel_launcher.py [-h] [-l CHECKPOINT] [-n N_CLASS] [-c N_CHANNEL]
                             [-t DATA_DIR] [-d SAVE_DIR] [-x EXT] [-s SHAPE]
ipykernel_launcher.py: error: unrecognized arguments: --ip=127.0.0.1 --stdin=9013 --control=9011 --hb=9010 --Session.signature_scheme="hmac-sha256" --Session.key=b"304357fb-a649-48ee-b3b3-45083948aa63" --shell=9012 --transport="tcp" --iopub=9014 --f=/home/getch/.local/share/jupyter/runtime/kernel-v2-1955771RfL9ExH7X1KI.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [9]:
import os
from glob import glob
len(glob('/home/getch/.cache/kagglehub/datasets/getachewworkineh/uganda-landcover/versions/1/samples/test/images/*.tif'))

697

In [1]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("getachewworkineh/uganda-landcover")

print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/getachewworkineh/uganda-landcover?dataset_version_number=2...


100%|██████████| 6.42G/6.42G [04:02<00:00, 28.5MB/s]  

Extracting files...





Path to dataset files: /home/getch/.cache/kagglehub/datasets/getachewworkineh/uganda-landcover/versions/2


In [5]:
import os
# os.listdir(path + '/' + 'landcover_data_v2')
path + '/' + 'landcover_data_v2'

'/home/getch/.cache/kagglehub/datasets/getachewworkineh/uganda-landcover/versions/2/landcover_data_v2'