# Load packages

In [1]:
import xarray as xr
import torch
import plotly.express as px
from torchvision.io import write_png
import torch.nn as nn

# Torch cuda check

In [2]:
print(torch.__version__)
# Cuda 12.1

2.1.0+cu121


In [10]:
torch.cuda.is_available()
torch.cuda.device_count()
torch.cuda.current_device()
torch.cuda.mem_get_info()

(3955884032, 25385107456)

# Load data

In [11]:
bm_xr = xr.load_dataset("~/data/nsidc/BedMachineAntarctica-v3.nc")

In [14]:
train_y_max = 1800000 - 500 # 1900 km, -500 for midpoints
train_y_min = -300000 # - 300 km

train_x_max = 850000 - 500 # 900 km
train_x_min = -200000 # - 200 km

y_range = train_y_max - train_y_min
x_range = train_x_max - train_x_min

print("Y range in km", y_range/1000)
print("X range in km", x_range/1000)

# 2450 training images
(train_x_max - train_x_min + 500) * (train_y_max - train_y_min + 500) / 900000

Y range in km 2099.5
X range in km 1049.5


2450000.0

In [12]:
# Function: converts from [0, 755] to {r, g, b} however some are not valid to convert back

def cont_755_to_rbb(input):
    # input: torch.Size([2450, 1, 60, 60])
    # 0 - 255:
    red = torch.where(condition = input < (255), input = input.int(), other = 255)
    # 255 - 510:
    green = torch.where(condition = (input > (255)), input = input.int() - 255, other = 0) # subtract 255
    green = torch.where(condition = green < 255, input = green.int(), other = 255)
    # 510 - 765:
    blue = torch.where(condition = input > (255*2), input = (input - (255*2)).int(), other = 0)

    rgb = torch.cat([red, green, blue], dim = 1)
    
    return(rgb.type(torch.uint8))

In [15]:
def export_train_images(bm_xr, d_y_min, d_y_max, d_x_min, d_x_max):

    train_bm_xr = bm_xr.sel(x = slice(d_x_min, d_x_max), 
                            y = slice(d_y_max, d_y_min))
    # Print to check dims
    print(train_bm_xr)

    train_tensor = torch.tensor(train_bm_xr.bed.values).unsqueeze(0)

    IMAGE_DIM = 60
    N_ROW_IMAGES = int(train_bm_xr.bed.values.shape[0]/IMAGE_DIM)
    N_COLUMN_IMAGES = int(train_bm_xr.bed.values.shape[1]/IMAGE_DIM)

    image_tensor = torch.empty(size = (0, 1, IMAGE_DIM, IMAGE_DIM))

    for row in range(0, N_ROW_IMAGES):
        row_min = row * IMAGE_DIM
        row_max = row_min + IMAGE_DIM

        for column in range(0, N_COLUMN_IMAGES):
            column_min = column * IMAGE_DIM
            column_max = column_min + IMAGE_DIM

            image_tensor = torch.cat((image_tensor, train_tensor[:, row_min : row_max, column_min : column_max].unsqueeze(0)), dim = 0)

    min_values, _ = torch.min(image_tensor.reshape(image_tensor.shape[0], -1), dim = -1)
    max_values, _ = torch.max(image_tensor.reshape(image_tensor.shape[0], -1), dim = -1)
    range_values = max_values - min_values

    norm = torch.subtract(input = image_tensor.reshape(image_tensor.shape[0], -1), other = min_values.unsqueeze(1))
    norm = torch.div(input = norm, other = range_values.unsqueeze(1))
    norm = norm.reshape(image_tensor.shape)
    cont_755 = norm * (3*255)

    # Low qual
    pool = nn.AvgPool2d(4, stride = 4)
    cont_755_lq = pool(cont_755)
    
    rgb = cont_755_to_rbb(cont_755)
    rgb_lq = cont_755_to_rbb(cont_755_lq)

    n_images = rgb.shape[0]

    for i in range(n_images):
        # Create filename: string with contant length
        number = str(i)
        while(len(number) < 4):
            number = '0' + number

        name_hr = "datasets/ANT_train/ANT_train_HR_sub/" + number + ".png"
        #name_lr = "datasets/ANT_train/ANT_train_LR_sub/X4_sub/" + number + "x4" + ".png"
        name_lr = "datasets/ANT_train/ANT_train_LR_sub/X4_sub/" + number + ".png"

        # following https://data.vision.ee.ethz.ch/cvl/DIV2K/
        write_png(rgb[i, :, :, :].type(torch.uint8), filename = name_hr)
        write_png(rgb_lq[i, :, :, :].type(torch.uint8), filename = name_lr)
    
export_train_images(bm_xr, train_y_min, train_y_max, train_x_min, train_x_max)

<xarray.Dataset>
Dimensions:    (x: 2100, y: 4200)
Coordinates:
  * x          (x) int32 -200000 -199500 -199000 ... 848500 849000 849500
  * y          (y) int32 1799500 1799000 1798500 ... -299000 -299500 -300000
Data variables:
    mapping    |S1 b''
    mask       (y, x) int8 2 2 2 2 2 2 2 2 2 2 2 2 2 ... 2 2 2 2 2 2 2 2 2 2 2 2
    firn       (y, x) float32 18.97 18.98 19.0 19.01 ... 29.95 29.94 29.94 29.93
    surface    (y, x) float32 1.945e+03 1.945e+03 ... 3.284e+03 3.283e+03
    thickness  (y, x) float32 1.75e+03 1.755e+03 ... 2.586e+03 2.578e+03
    bed        (y, x) float32 194.9 190.7 188.6 188.9 ... 704.5 697.6 705.3
    errbed     (y, x) float32 34.0 32.0 30.0 30.0 30.0 ... 35.0 35.0 35.0 35.0
    source     (y, x) int8 5 5 5 5 5 5 5 5 5 5 5 5 5 ... 5 5 5 5 5 5 5 5 5 5 5 5
    dataid     (y, x) int8 0 2 0 0 0 0 2 2 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0
    geoid      (y, x) int16 12 12 12 12 12 12 12 ... -23 -23 -23 -23 -23 -23 -23
Attributes: (12/17)
    Conventions:   

# Repeat for Validation data

In [17]:
val_y_max = 1800000 - 500# 1900 km
val_y_min = -300000 # - 300 km

val_x_max = 910000 - 500 # 900 km
val_x_min = 850000 # - 200 km

val_y_range = val_y_max - val_y_min
val_x_range = val_x_max - val_x_min

print("Y range in km", val_y_range/1000)
print("X range in km", val_x_range/1000)

# 138 training images
(val_x_max - val_x_min + 500) * (val_y_max - val_y_min + 500) / 900000

Y range in km 2099.5
X range in km 59.5


140000.0

In [8]:
def make_val_data(bm_xr, d_y_min, d_y_max, d_x_min, d_x_max):

    train_bm_xr = bm_xr.sel(x = slice(d_x_min, d_x_max), 
                            y = slice(d_y_max, d_y_min))
    # Print to check dims
    print(train_bm_xr)

    train_tensor = torch.tensor(train_bm_xr.bed.values).unsqueeze(0)

    IMAGE_DIM = 60
    N_ROW_IMAGES = int(train_bm_xr.bed.values.shape[0]/IMAGE_DIM)
    N_COLUMN_IMAGES = int(train_bm_xr.bed.values.shape[1]/IMAGE_DIM)

    image_tensor = torch.empty(size = (0, 1, IMAGE_DIM, IMAGE_DIM))

    for row in range(0, N_ROW_IMAGES):
        row_min = row * IMAGE_DIM
        row_max = row_min + IMAGE_DIM

        for column in range(0, N_COLUMN_IMAGES):
            column_min = column * IMAGE_DIM
            column_max = column_min + IMAGE_DIM

            image_tensor = torch.cat((image_tensor, train_tensor[:, row_min : row_max, column_min : column_max].unsqueeze(0)), dim = 0)

    min_values, _ = torch.min(image_tensor.reshape(image_tensor.shape[0], -1), dim = -1)
    max_values, _ = torch.max(image_tensor.reshape(image_tensor.shape[0], -1), dim = -1)
    range_values = max_values - min_values

    norm = torch.subtract(input = image_tensor.reshape(image_tensor.shape[0], -1), other = min_values.unsqueeze(1))
    norm = torch.div(input = norm, other = range_values.unsqueeze(1))
    norm = norm.reshape(image_tensor.shape)
    cont_755 = norm * (3*255)

    # Low qual
    pool = nn.AvgPool2d(4, stride = 4)
    cont_755_lq = pool(cont_755)
    
    rgb = cont_755_to_rbb(cont_755)
    rgb_lq = cont_755_to_rbb(cont_755_lq)

    n_images = rgb.shape[0]

    for i in range(n_images):
        # Create filename: string with contant length
        number = str(i)
        while(len(number) < 4):
            number = '0' + number

        name_hr = "datasets/ANT_val/ANT_val_HR_sub/" + number + ".png"
        # name_lr = "datasets/ANT_val/ANT_val_LR_sub/X4_sub/" + number + "x4" + ".png"
        name_lr = "datasets/ANT_val/ANT_val_LR_sub/X4_sub/" + number + ".png"

        # following https://data.vision.ee.ethz.ch/cvl/DIV2K/
        write_png(rgb[i, :, :, :].type(torch.uint8), filename = name_hr)
        write_png(rgb_lq[i, :, :, :].type(torch.uint8), filename = name_lr)
    
make_val_data(bm_xr, val_y_min, val_y_max, val_x_min, val_x_max)

<xarray.Dataset>
Dimensions:    (x: 120, y: 4200)
Coordinates:
  * x          (x) int32 850000 850500 851000 851500 ... 908500 909000 909500
  * y          (y) int32 1799500 1799000 1798500 ... -299000 -299500 -300000
Data variables:
    mapping    |S1 b''
    mask       (y, x) int8 1 1 1 1 1 1 1 1 1 1 1 1 2 ... 2 2 2 2 2 2 2 2 2 2 2 2
    firn       (y, x) float32 0.0 0.0 0.0 0.0 0.0 ... 29.91 29.92 29.92 29.93
    surface    (y, x) float32 1.298e+03 1.361e+03 ... 3.321e+03 3.321e+03
    thickness  (y, x) float32 0.0 0.0 0.0 0.0 ... 2.61e+03 2.616e+03 2.622e+03
    bed        (y, x) float32 1.298e+03 1.361e+03 1.304e+03 ... 705.2 699.1
    errbed     (y, x) float32 10.0 10.0 10.0 10.0 ... 182.0 192.0 202.0 212.0
    source     (y, x) int8 1 1 1 1 1 1 1 1 1 1 1 1 3 ... 5 5 5 5 5 5 5 5 5 5 5 5
    dataid     (y, x) int8 1 1 1 1 1 1 1 1 1 1 1 1 0 ... 0 0 0 0 0 0 0 0 0 0 0 0
    geoid      (y, x) int16 17 17 17 17 17 17 17 ... -22 -22 -22 -22 -22 -22 -22
Attributes: (12/17)
    Convention