# Prediction and Submission

This notebook loads a previously trained model and the test dataset, runs inference on every test chip, and saves the predictions in the competition submission format.

In [1]:
import os
from pathlib import Path

from biomasstry.datasets import TemporalSentinel2Dataset
from biomasstry.models import TemporalSentinelModel, UTAE
from PIL import Image
import torch
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

In [2]:
# Directory the trained model checkpoint is read from.
ARTIFACTS_DIR = Path("/notebooks/artifacts")
# Directory the per-chip prediction TIFFs are written to.
PREDICTIONS_DIR = Path("/notebooks/submission")
# Checkpoint used for this submission. Earlier candidates, kept for reference:
#   "20230102_TemporalS2_B16_E10.pt"
#   "20230109_TemporalS2_B64_E10.pt"
model_file = "20230112_UTAE_S2_B32_E20.pt"
model_path = ARTIFACTS_DIR / model_file

## Test Data

In [3]:
# Testing dataset, built with train=False (presumably this selects the
# held-out test chips — confirm against TemporalSentinel2Dataset).
# The prediction loop at the end of the notebook iterates over it and reads
# the 'chip_id' and 'image' entries of every sample.
testds = TemporalSentinel2Dataset(train=False)

## Pre-trained Model

In [4]:
# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
# Report which device was selected so the run log shows where inference happens.
print(f"Device: {device}")

Device: cuda


In [6]:
n_tsamples = 5  # only referenced by the commented-out TemporalSentinelModel alternative
input_nc = 10   # number of input channels handed to the model constructor
# Alternative architecture (presumably pairs with the TemporalS2 checkpoints — confirm):
# model = TemporalSentinelModel(n_tsamples=n_tsamples,
#     input_nc=input_nc,
#     output_nc=1)
model = UTAE(input_nc)
# map_location remaps tensors that were saved from a GPU onto `device`, so the
# checkpoint also loads when the notebook has fallen back to the CPU.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
# Switch to evaluation mode for inference. As the last expression of the cell,
# the returned module is also what the notebook displays.
model.eval()

UTAE(
  (in_conv): ConvBlock(
    (conv): ConvLayer(
      (conv): Sequential(
        (0): Conv2d(10, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
        (1): GroupNorm(4, 64, eps=1e-05, affine=True)
        (2): ReLU()
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
        (4): GroupNorm(4, 64, eps=1e-05, affine=True)
        (5): ReLU()
      )
    )
  )
  (down_blocks): ModuleList(
    (0): DownConvBlock(
      (down): ConvLayer(
        (conv): Sequential(
          (0): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
          (1): GroupNorm(4, 64, eps=1e-05, affine=True)
          (2): ReLU()
        )
      )
      (conv1): ConvLayer(
        (conv): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
          (1): GroupNorm(4, 64, eps=1e-05, affine=True)
          (2): ReLU()
        )
   

In [7]:
def predict_agbm(inputs, model):
    """Run `model` on `inputs` and return the prediction as a NumPy array.

    Args:
        inputs: Tensor (or whatever `model` accepts) passed straight to `model`.
        model: Callable that returns a tensor when called with `inputs`.

    Returns:
        The model output moved to the CPU, with all size-1 dimensions removed,
        converted to a `numpy.ndarray`.
    """
    # Inference only: disabling autograd avoids building a graph that would be
    # thrown away immediately, which saves memory on every call.
    with torch.no_grad():
        pred = model(inputs)
    return pred.detach().squeeze().cpu().numpy()

In [8]:
def save_agbm(agbm_pred, chipid):
    """Write one prediction array to `<PREDICTIONS_DIR>/<chipid>_agbm.tif`.

    Args:
        agbm_pred: Array accepted by `PIL.Image.fromarray`.
        chipid: Identifier used as the file-name prefix.
    """
    # Create the output directory on first use so the save does not fail on a
    # fresh environment where it has not been created yet.
    os.makedirs(PREDICTIONS_DIR, exist_ok=True)
    im = Image.fromarray(agbm_pred)
    save_path = os.path.join(PREDICTIONS_DIR, f'{chipid}_agbm.tif')
    im.save(save_path, format='TIFF', save_all=True)

In [None]:
# Run inference one test sample at a time and write one TIFF per chip.
for timg in tqdm(testds):
    chipid = timg['chip_id']
    # Stack the sample's sequence of image tensors into a single tensor
    # (presumably one tensor per time step — confirm against the dataset),
    # add a leading batch dimension of size 1, and move it to the device.
    inputs = torch.stack(timg['image']).unsqueeze(dim=0).to(device)
    agbm = predict_agbm(inputs, model)
    save_agbm(agbm, chipid)

  0%|          | 0/2773 [00:00<?, ?it/s]