In [1]:
from src.unet.model import UNet

# Restore the trained U-Net from the checkpoint file.
model = UNet.load_from_checkpoint('unet-resnet18-val_mathCorrCoef=0.8300.ckpt')
# This notebook only runs inference, so switch the network to evaluation mode
# explicitly instead of relying on whatever mode the freshly loaded module is in.
# NOTE(review): harmless but redundant if UNet.predict already calls eval() itself.
model.eval()
# Last expression of the cell: display the hyper-parameters stored with the checkpoint.
model.hparams

"batch_size":     5
"encoder":        resnet18
"gpus":           1
"log":            False
"lr":             0.001
"max_epochs":     60
"num_workers":    20
"path":           data/eopatches
"precision":      16
"resume":         None
"shuffle":        False
"train_batches":  1
"val_batches":    1
"val_with_train": True

In [2]:
from src.unet.datamodules.dm_fast import UNetDataModule

# Build the project's data module and run its setup step; the following cells
# read the test index from `dm.test_df`.
# Note: the export loop further down creates its own DataLoader
# (batch_size=len(ds)), so it does not go through this batch_size.
dm = UNetDataModule(batch_size=32)
dm.setup()

In [3]:
# Preview the test index (columns shown below: `file` and `step`).
dm.test_df

Unnamed: 0,file,step
0,data/eopatches/test_processed/eopatch-12,0
1,data/eopatches/test_processed/eopatch-12,1
2,data/eopatches/test_processed/eopatch-12,2
3,data/eopatches/test_processed/eopatch-12,3
4,data/eopatches/test_processed/eopatch-12,4
...,...,...
898,data/eopatches/test_processed/eopatch-17,20
899,data/eopatches/test_processed/eopatch-17,21
900,data/eopatches/test_processed/eopatch-17,22
901,data/eopatches/test_processed/eopatch-17,23


In [4]:
# Distinct patch directories referenced by the test index (order of first appearance).
patches = dm.test_df['file'].unique()
# Fail early, with a readable message, if the index does not contain the
# expected number of test patches.
assert len(patches) == 25, f'expected 25 test patches, found {len(patches)}'

In [5]:
# Show every index row that belongs to the first test patch.
dm.test_df.loc[dm.test_df['file'] == patches[0]]

Unnamed: 0,file,step
0,data/eopatches/test_processed/eopatch-12,0
1,data/eopatches/test_processed/eopatch-12,1
2,data/eopatches/test_processed/eopatch-12,2
3,data/eopatches/test_processed/eopatch-12,3
4,data/eopatches/test_processed/eopatch-12,4
5,data/eopatches/test_processed/eopatch-12,5
6,data/eopatches/test_processed/eopatch-12,6
7,data/eopatches/test_processed/eopatch-12,7
8,data/eopatches/test_processed/eopatch-12,8
9,data/eopatches/test_processed/eopatch-12,9


In [6]:
import os

import rasterio
from eolearn.core import EOPatch
from rasterio.profiles import DefaultGTiffProfile
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.unet.datasets.ds_fast import Dataset

# --- Inference + export configuration -------------------------------------
path = './data/eopatches/test'    # EOPatches loaded below for bbox / size metadata
submission_path = 'submission'    # output directory: one GeoTIFF per patch
th = 0.5                          # threshold applied to the averaged prediction

# rasterio does not create missing directories; make sure the output folder
# exists before the (slow) loop starts writing into it.
os.makedirs(submission_path, exist_ok=True)

for patch_file in tqdm(patches):
    # Feed all index rows of this patch to the model as one single batch
    # (batch_size=len(ds), and only the first batch is taken).
    patch_df = dm.test_df[dm.test_df.file == patch_file]
    ds = Dataset(patch_df, mode='test')
    dl = DataLoader(ds, batch_size=len(ds))
    imgs = next(iter(dl))
    preds = model.predict(imgs)
    # Average over axis 0 (the batch), squeeze the then-leading axis
    # (presumably a singleton channel axis -- TODO confirm), trim a 24-pixel
    # border on every side and binarise with `th`.
    mask = preds.mean(0).squeeze(0)[24:-24,24:-24] > th

    # save
    # Keep the loop variable intact; `patch` is only the directory's base name.
    patch = patch_file.split("/")[-1]
    file_path = f'{path}/{patch}'
    eopatch = EOPatch.load(file_path)
    # Thanks to @cayala
    # Affine transform for the patch bounding box on a size_x x size_y grid ...
    tfm = rasterio.transform.from_bounds(*eopatch.bbox, eopatch.meta_info['size_x'], eopatch.meta_info['size_y'])
    # ... then shrink the pixel width/height by 4 so the same extent is covered
    # at four times the resolution.
    tfm = rasterio.Affine(tfm.a/4, tfm.b, tfm.c, tfm.d, tfm.e/4, tfm.f)
    profile = DefaultGTiffProfile(count=1)
    profile.update(
        transform=tfm,
        width=2000,
        height=2000,
        crs='epsg:32633'
    )
    # Single-band GeoTIFF; the read-back cell further down shows the stored
    # band as uint8 with values 0/1.
    with rasterio.open(f'{submission_path}/{patch}.tif', 'w', **profile) as dst_dataset:
        dst_dataset.write(mask, 1)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
 36%|███▌      | 9/25 [01:18<02:28,  9.28s/it]

In [None]:
# Pack the contents of the submission directory into a gzip-compressed archive
# named submission.tar.gz (-C switches into the directory first, so member
# paths are relative to it; -v lists every file added).
!tar -C {submission_path} -zcvf submission.tar.gz .

In [None]:
# Sanity check: reopen one exported raster and show its CRS, bounds and path.
sample = 'eopatch-12'
# Use the same `submission_path` the export loop wrote to, rather than a
# second hard-coded copy of the directory name.
# The dataset is intentionally left open: the next cell reads a band from it.
mask_ds = rasterio.open(f'{submission_path}/{sample}.tif')
mask_ds.crs, mask_ds.bounds, mask_ds.name

(CRS.from_epsg(32633),
 BoundingBox(left=629900.0, bottom=1565200.0, right=634900.0, top=1570200.0),
 'submission/eopatch-12.tif')

In [None]:
# Read band 1 of the raster opened in the previous cell and summarise it
# (dtype, shape, value range) to confirm what was written to disk.
# Note: this rebinds `mask`, the name the export loop also uses.
mask = mask_ds.read(1)
mask.dtype, mask.shape, mask.min(), mask.max()

(dtype('uint8'), (2000, 2000), 0, 1)