In [1]:
import os
import time
import pathlib

import tqdm
import torch
import skimage
import rasterio
import matplotlib.pyplot as plt

from hashtagdeep.models import FCDenseNet103
from hashtagdeep.dataset import MiniFranceDFC22
from hashtagdeep.utils.colormaps import dfc22_labels_palette

In [2]:
# Root of the DFC2022 data. Set the DFC22_DATA_DIR environment variable to point
# elsewhere instead of editing this machine-specific absolute path.
DATA_DIR = os.environ.get('DFC22_DATA_DIR', '/home/dubrovin/Projects/Data/DFC2022/')

# Dataset built with labeled=False / val=True (presumably the unlabeled validation
# split — confirm against MiniFranceDFC22) and with no augmentation or transform.
dataset = MiniFranceDFC22(
    base_dir=DATA_DIR,
    labeled=False,
    val=True,
    augmentation=None,
    transform=None,
)

In [3]:
# Build the network (4 input channels, 14 output classes) and load the
# experiment00_v1 submission weights for inference.
CHECKPOINT_PATH = 'checkpoints/state_dicts_for_submission_models/experiment00_v1.pt'
DEVICE = 'cuda:1'

model = FCDenseNet103(in_channels=4, n_classes=14)
# map_location='cpu' deserialises the tensors on the host instead of on whichever
# GPU they were saved from (that device may be absent or busy). The model is moved
# to DEVICE afterwards.
# NOTE: torch.load unpickles the file, so only load checkpoints from trusted sources.
state_dict = torch.load(CHECKPOINT_PATH, map_location='cpu')
model.load_state_dict(state_dict)
model.eval()  # evaluation mode (affects dropout / batch-norm layers, if present)
model.to(DEVICE);

In [4]:
# Run tiled inference over every dataset image. Each prediction is written as a
# single-band uint8 GeoTIFF under predictions0/<region>/, reusing the source
# image's transform and CRS.
TILE_SIZE = 1000  # max tile edge (pixels) fed to the model in one forward pass

base_predictions_dir = pathlib.Path('predictions0')
base_predictions_dir.mkdir(exist_ok=True)

# Infer on whatever device the model currently lives on instead of repeating a device string.
device = next(model.parameters()).device


def _tile_spans(size, tile_size=TILE_SIZE):
    """Split [0, size) into the fewest near-equal (start, stop) spans of at most tile_size.

    For size=2000, tile_size=1000 this yields (0, 1000), (1000, 2000), the same
    split as a fixed cut at 1000. Axes shorter than one tile get a single span
    (no empty slice), and longer axes get more spans.
    """
    n_tiles = max(1, -(-size // tile_size))  # ceil division
    edges = [k * size // n_tiles for k in range(n_tiles + 1)]
    return list(zip(edges[:-1], edges[1:]))


for i, item in tqdm.tqdm(enumerate(dataset), total=len(dataset)):
    x = item['image']
    height, width = x.shape[1:]
    # uint8 matches the GeoTIFF band dtype declared below; labels 1..14 fit in it
    out = torch.empty(height, width, dtype=torch.uint8)

    with torch.no_grad():
        for top, bottom in _tile_spans(height):
            for left, right in _tile_spans(width):
                logits = model(x[:, top:bottom, left:right].unsqueeze(0).to(device))
                # Take argmax over the class axis on the device so that only the
                # small index map crosses to the host. Indexing [0] drops the batch
                # axis without squeezing away a size-1 spatial axis.
                out[top:bottom, left:right] = logits[0].argmax(0).to(torch.uint8).cpu()

    # transform class indices (0-based) into class labels (1-based)
    out += 1

    path = dataset.true_color_paths[i]
    # assumes a .../<region>/<subdir>/<file> layout: the region is 3rd from the end
    parts = pathlib.Path(path).parts
    region, file = parts[-3], parts[-1]

    region_dir = base_predictions_dir / region
    region_dir.mkdir(exist_ok=True)

    file_basename = file.split('.')[0]
    prediction_file = region_dir / f'{file_basename}_prediction.tif'

    with rasterio.open(path) as src:
        transform = src.transform
        crs = src.crs

    with rasterio.open(
        prediction_file, 'w',
        driver='GTiff',
        height=height, width=width, count=1, dtype='uint8',
        transform=transform, crs=crs,
    ) as dst:
        # rasterio expects a (bands, rows, cols) ndarray, not a torch tensor
        dst.write(out.unsqueeze(0).numpy())

100%|███████████████████████████████████████████████| 2066/2066 [1:05:03<00:00,  1.89s/it]
