# Produce thematic map from model

In [2]:
# get torch functions
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torch.autograd.profiler

# get image process functions
import cv2
from PIL import Image
#import torchvision
#from scipy import ndimage

# get linear algebra plus vis 
import matplotlib.pyplot as plt
import numpy as np

# get UNET
from model_struct import *

# get subshell functions
from glob import glob
import os, sys
#os.environ['KMP_DUPLICATE_LIB_OK']='True'

In [None]:
# Project root that holds the trained U-Net artefacts.
root = 'some_path.../Unet_binary'
# Checkpoint to use, relative to `root` (leading slash is intentional: the
# two strings are concatenated, not joined with os.path.join).
used_model = '/585_db/saved_model/model_at_epoch_8_0.01262.pt'
# glob returns a list of matching paths; it is empty when the file is missing.
model = glob(f"{root}{used_model}")
model

In [4]:
# Inference runs on the CPU.
device = torch.device('cpu')
loaded_m = UNet(n_channels=3, n_classes=1, bilinear=True).to(device)

# `model` is the list of checkpoint paths produced by glob(); fail with a
# clear message instead of an IndexError when nothing matched.
if not model:
    raise FileNotFoundError("No model checkpoint matched the configured path")

# map_location remaps tensors saved from another device (e.g. CUDA) onto
# `device`, so the checkpoint also loads on machines without a GPU.
state_dict = torch.load(model[0], map_location=device)
loaded_m.load_state_dict(state_dict)

# Switch to evaluation mode; the returned module is displayed as the cell output.
loaded_m.eval()

UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(128, eps=1e-05, moment

In [5]:
# Util to read files in alphanum order
import re
def sorted_alphanumeric(data):
    """Return the strings in `data` sorted in natural (alphanumeric) order.

    Runs of ASCII digits are compared as integers and everything else is
    compared case-insensitively, so 'img2.jpg' sorts before 'img10.jpg'.

    Args:
        data: iterable of strings (e.g. file paths).

    Returns:
        A new list with the items of `data` in natural order.
    """
    def alphanum_key(key):
        # Splitting on a capturing group always yields text at even indices
        # and digit runs at odd indices. Classifying by position (rather than
        # str.isdigit) keeps the key types aligned across items and avoids
        # int() on non-ASCII "digit" characters such as superscripts.
        segments = re.split('([0-9]+)', key)
        return [int(seg) if pos % 2 else seg.lower()
                for pos, seg in enumerate(segments)]

    return sorted(data, key=alphanum_key)

In [6]:
# Define Dataset
class simple_segmentDS(torch.utils.data.Dataset):
    """Dataset that serves the '*.jpg' images of one directory as tensors.

    Files matching ``root + imdir + '*.jpg'`` are listed once, in natural
    (alphanumeric) order, at construction time.

    Args:
        root: base path; concatenated with `imdir` as a plain string.
        imdir: image directory appended to `root`; the '*.jpg' pattern is
            appended directly, so it should end with a path separator.
        training: when truthy, `transform` (if given) is applied to each image.
        transform: optional callable applied to the RGB PIL image in training
            mode; its result is converted with ``np.array``.
        img_size: (width, height) passed to ``cv2.resize``; defaults to the
            previously hard-coded (201, 201).
    """

    def __init__(self, root, imdir, training=False, transform=None, img_size=(201, 201)):
        super().__init__()
        self.root = root
        self.imdir = imdir
        self.training = training
        self.transform = transform
        self.img_size = img_size
        self.IMG_NAMES = sorted_alphanumeric(glob(self.root + self.imdir + '*.jpg'))

    def __len__(self):
        """Return the number of images found."""
        return len(self.IMG_NAMES)

    def __getitem__(self, idx):
        """Load image `idx` and return it as a float32 tensor scaled by 1/255.

        The channel axis is moved to the front, so an H x W x C array becomes
        C x H x W after resizing to `img_size`.
        """
        img_path = self.IMG_NAMES[idx]

        # convert() returns an independent in-memory image, so the file handle
        # can be released as soon as the with-block ends.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")

        # The transform is applied to the PIL image directly; this avoids the
        # ndarray -> PIL round-trip, which relied on a `transforms` name that
        # is not imported in this notebook.
        if self.training and self.transform is not None:
            image = self.transform(image)

        image = np.array(image)
        image = cv2.resize(image, self.img_size) / 255.0
        image = np.moveaxis(image, -1, 0)

        return torch.tensor(image, dtype=torch.float32)

In [7]:
# Location of the cropped true-colour tiles to run the model on.
root_2 = '....some_path../raster_calc_gdal'
imdir = '/T34/arable_mask/cropped/T34TCT_TCI/'

# Inference only: no augmentation, and the file order is preserved
# (shuffle=False) so every prediction can be traced back to its tile.
testset = simple_segmentDS(
    root_2,
    imdir,
    training=False,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=9,
    shuffle=False,
)

In [8]:
# Number of batches in the loader: 2916 tiles at 9 images per batch = 324 batches.
print(len(testloader))

324


In [9]:
# Apply model on images
# took around 10-12 mins on 2916 images (201x201) when run one image at a time

import torchvision
import torchvision.transforms

tosave = root_2 + imdir + "preds"
# save_image does not create missing directories, so make sure the target exists.
os.makedirs(tosave, exist_ok=True)

with torch.no_grad():
    for batch, images in enumerate(testloader):
        # One forward pass per batch instead of one per image; the model is in
        # eval mode, so each sample's prediction does not depend on the others.
        preds = torch.sigmoid(loaded_m(images.to(device)))
        # Binarise the probabilities at 0.5.
        preds = (preds > 0.5).float()
        for j in range(len(preds)):
            # NOTE(review): JPEG is lossy, so the saved mask is not strictly
            # binary; the extension is kept to avoid breaking downstream
            # consumers, but a lossless format (e.g. PNG) would be safer.
            torchvision.utils.save_image(preds[j:j+1], f"{tosave}/pred_{batch}_{j}.jpg")
        