In [1]:
import os
import re
from glob import glob
from datetime import datetime
import torch
import xarray
import rasterio
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import v2 as transforms


batch_size = 64
data_cubes_dir = "./Data/Images/Data Cubes"
train_dir = f"{data_cubes_dir}/Train"
val_dir = f"{data_cubes_dir}/Val"
test_dir = f"{data_cubes_dir}/Test"
# glob() returns paths in a filesystem-dependent order; sort each split so the
# cube order (and therefore the dataloader order) is reproducible across runs
# and machines.
train_cubes_paths = sorted(glob(f"{train_dir}/*.zarr"))
val_cubes_paths = sorted(glob(f"{val_dir}/*.zarr"))
test_cubes_paths = sorted(glob(f"{test_dir}/*.zarr"))
# Train, then Val, then Test. The dataset/dataloader and the patch-writing loop
# index into this list, so its order determines which cube each latent maps to.
all_cubes_paths = train_cubes_paths + val_cubes_paths + test_cubes_paths


class AtlanticForestDataset(Dataset):
  """Map-style dataset that loads one zarr data cube per item as a float tensor.

  Args:
    zarr_paths: sequence of paths to ``.zarr`` stores, each containing a
      ``data_cube`` variable. Item ``idx`` is read from ``zarr_paths[idx]``.
    transform: optional callable applied to each cube. It is called with
      dims 0 and 1 swapped, and its result is swapped back, so a per-channel
      transform such as ``Normalize`` sees dim 0 of the cube as its channels.
  """

  def __init__(self, zarr_paths, transform=None):
    self.zarr_paths = zarr_paths
    self.transform = transform

  def __len__(self):
    return len(self.zarr_paths)

  def __getitem__(self, idx):
    """Load cube ``idx`` as float32, replace NaNs with 0, then apply the transform."""
    zarr_path = self.zarr_paths[idx]
    # Materialise the values, then close the store instead of leaving one open
    # dataset behind for every item that is loaded.
    with xarray.open_zarr(zarr_path, consolidated=False) as dataset:
      values = dataset['data_cube'].values
    tensor = torch.from_numpy(values).float()
    # NaNs become 0 *before* normalisation; torch.nan_to_num also clamps +/-inf
    # to the dtype's finite extremes.
    tensor = torch.nan_to_num(tensor)
    return self._apply_transform(tensor)

  def _apply_transform(self, tensor):
    """Apply ``self.transform`` with dims 0 and 1 swapped; no-op when it is None."""
    if self.transform is None:
      return tensor
    # v2.Normalize expects (..., C, H, W). Swapping dims 0 and 1 puts dim 0 of
    # the cube in the channel slot. Assumes a 4-D cube whose dim 0 is the
    # 7-band axis (the model's first Conv3d takes in_channels=7) — TODO confirm.
    swapped = torch.transpose(tensor, 0, 1)
    return torch.transpose(self.transform(swapped), 0, 1)


# Per-channel normalisation statistics computed on the train split (see the
# "AutoEncoder Train" notebook). There is one value per input channel: 7,
# matching the model's in_channels.
mean = torch.tensor((
  881.1200, 996.9095, 839.9584, 3211.0962, 1839.4138, 5219.2334, 2899.2646,
))
std = torch.tensor((
  1692.1893, 1562.1083, 1555.0720, 1280.2616, 1096.6338, 2067.6047, 1164.6179,
))

# Normalize receives plain Python floats read back from the float32 tensors.
transform = transforms.Compose(
  [transforms.Normalize(mean=mean.tolist(), std=std.tolist())]
)

all_dataset = AtlanticForestDataset(all_cubes_paths, transform=transform)
# shuffle=False keeps the loader order aligned with all_cubes_paths; the
# patch-writing loop relies on that to map each latent back to its cube file.
all_dataloader = DataLoader(
  dataset=all_dataset,
  shuffle=False,
  batch_size=batch_size,
)

In [2]:
class AutoEncoder(nn.Module):
  def __init__(self):
    super(AutoEncoder, self).__init__()
    self.encoder = nn.Sequential(
      nn.Conv3d(
        in_channels=7, 
        out_channels=5, 
        kernel_size=(3, 3, 3),
        stride=(2, 1, 1),
        padding=(0, 1, 1)),
      nn.BatchNorm3d(5),
      nn.LeakyReLU(),
      nn.Dropout(0.2),
      nn.Conv3d(
        in_channels=5, 
        out_channels=5, 
        kernel_size=(5, 3, 3),
        stride=(1, 1, 1),
        padding=(0, 1, 1)),
      nn.BatchNorm3d(5),
      nn.LeakyReLU(),
      nn.Dropout(0.2),
      nn.Conv3d(
        in_channels=5, 
        out_channels=3, 
        kernel_size=(5, 3, 3),
        stride=(1, 1, 1),
        padding=(0, 1, 1)),
      nn.BatchNorm3d(3),
      nn.LeakyReLU(),
      nn.Dropout(0.2),
      nn.Conv3d(
        in_channels=3, 
        out_channels=1, 
        kernel_size=(3, 3, 3),
        stride=(1, 1, 1),
        padding=(0, 1, 1)),
      nn.BatchNorm3d(1)
    )
    self.decoder = nn.Sequential(
      nn.ConvTranspose3d(
        in_channels=1, 
        out_channels=3, 
        kernel_size=(3, 3, 3),
        stride=(1, 1, 1),
        padding=(0, 1, 1)),
      nn.BatchNorm3d(3),
      nn.LeakyReLU(),
      nn.Dropout(0.2),
      nn.ConvTranspose3d(
        in_channels=3, 
        out_channels=5, 
        kernel_size=(5, 3, 3),
        stride=(1, 1, 1),
        padding=(0, 1, 1)),
      nn.BatchNorm3d(5),
      nn.LeakyReLU(),
      nn.Dropout(0.2),
      nn.ConvTranspose3d(
        in_channels=5, 
        out_channels=5, 
        kernel_size=(5, 3, 3),
        stride=(1, 1, 1),
        padding=(0, 1, 1)),
      nn.BatchNorm3d(5),
      nn.LeakyReLU(),
      nn.Dropout(0.2),
      nn.ConvTranspose3d(
        in_channels=5, 
        out_channels=7, 
        kernel_size=(3, 3, 3),
        stride=(2, 1, 1),
        padding=(0, 1, 1)),
      nn.BatchNorm3d(7)
    )  

  def encode(self, x):
    return self.encoder(x)
  
  def decode(self, x):
    return self.decoder(x)

  def forward(self, x):
    encoded = self.encoder(x)
    decoded = self.decoder(encoded)
    return decoded

model = AutoEncoder()
model_path = "./Model/AutoEncoder 10.pth"
# Inference in this notebook runs on the CPU: batches are never moved to a
# device and latents are converted with .numpy(). Map the checkpoint to the CPU
# so weights saved from a CUDA model also load on a machine without a GPU.
model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
model.eval()

AutoEncoder(
  (encoder): Sequential(
    (0): Conv3d(7, 5, kernel_size=(3, 3, 3), stride=(2, 1, 1), padding=(0, 1, 1))
    (1): BatchNorm3d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.01)
    (3): Dropout(p=0.2, inplace=False)
    (4): Conv3d(5, 5, kernel_size=(5, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
    (5): BatchNorm3d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): LeakyReLU(negative_slope=0.01)
    (7): Dropout(p=0.2, inplace=False)
    (8): Conv3d(5, 3, kernel_size=(5, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
    (9): BatchNorm3d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.01)
    (11): Dropout(p=0.2, inplace=False)
    (12): Conv3d(3, 1, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
    (13): BatchNorm3d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (decoder): Sequential(
    (0): Con

In [3]:
# Encode every cube and save its latent map as a single-band GeoTIFF that
# reuses the source cube's CRS and affine transform.
# NOTE(review): `cube.rio` needs the rioxarray accessor, which this notebook
# never imports explicitly — confirm rioxarray is installed and registered.
output_folder = "./Data/Images/Latent Patches"
os.makedirs(output_folder, exist_ok=True)  # rasterio does not create folders

pattern = re.compile(r"Data_Cube_(\d+)\.zarr")
# Print one dot per ~1% of cubes. max(1, ...) avoids a modulo by zero when
# there are fewer than 100 cubes.
dot_every = max(1, len(all_cubes_paths) // 100)

print("[INFO] Writing patches [", end="")

# all_dataloader iterates all_cubes_paths in order (shuffle=False), so the
# k-th latent produced below belongs to all_cubes_paths[k].
cube_index = 0
with torch.no_grad():  # inference only: no autograd graph needed
  for batch in all_dataloader:
    encoded_batch = model.encode(batch)
    # Expected shape (batch, 1, 1, H, W). Drop only the channel and depth
    # axes: a bare squeeze() also drops the batch axis when the last batch
    # holds a single cube.
    if encoded_batch.shape[1:3] != (1, 1):
      raise ValueError(
        f"Expected encoder output (B, 1, 1, H, W), got {tuple(encoded_batch.shape)}")
    latents = encoded_batch[:, 0, 0].numpy()

    for latent in latents:
      cube_path = all_cubes_paths[cube_index]
      cube_index += 1
      if not cube_index % dot_every:
        print(".", end="")

      match = pattern.search(cube_path)
      if match is None:
        raise ValueError(f"Cannot parse a cube id from {cube_path!r}")
      cube_id = int(match.group(1))

      # Only the georeferencing is needed from the source cube; close it
      # right after reading.
      with xarray.open_zarr(cube_path, consolidated=False) as cube:
        crs_wkt = cube.spatial_ref.crs_wkt
        geo_transform = cube.rio.transform()

      height, width = latent.shape
      out_meta = {
        'driver': 'GTiff',
        'dtype': "float32",
        'nodata': -9999.0,
        'width': width,
        'height': height,
        'count': 1,
        'crs': crs_wkt,
        'transform': geo_transform
      }

      output_file = f"{output_folder}/Latent_Patch_{cube_id}.tif"
      with rasterio.open(output_file, "w", **out_meta) as dst:
        dst.write(latent, 1)

print("]")

[INFO] Writing patches [....................................................................................................]
