# U-Net

In [1]:
import numpy as np
import pandas as pd
import xarray as xr
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import xarray.plot as xplt
import cftime

In [2]:
import pysteps

Pysteps configuration file found at: /home/henry/anaconda3/envs/downscaling/lib/python3.9/site-packages/pysteps/pystepsrc



In [3]:
# Coordinate reference systems used with cartopy.
# PlateCarree: plain longitude/latitude.
platecarree = ccrs.PlateCarree()
# Rotated-pole CRS with the pole at lon 177.5, lat 37.5 — presumably the native
# grid of the convection-permitting model (per the name); confirm against the
# dataset's grid_mapping attributes.
cp_model_rotated_pole = ccrs.RotatedPole(
    pole_latitude=37.5,
    pole_longitude=177.5,
)

In [4]:
# Target: 2.2km precipitation. "pr" is renamed to "target_pr" so it does not clash
# with the coarse "pr" input variable when the two datasets are merged later.
# Keep ensemble member 1 and two years of days, then drop non-index coordinates
# into data variables and keep only the target.
cpmdata = (
    xr.open_mfdataset("../../../data/2.2km/rcp85/01/pr/*.nc")
    .rename({"pr": "target_pr"})
    .sel(ensemble_member=1, time=slice("1980-12-01", "1982-11-30"))
    .reset_coords()
)[["target_pr"]]
cpmdata

Unnamed: 0,Array,Chunk
Bytes,805.58 MiB,402.79 MiB
Shape,"(720, 606, 484)","(360, 606, 484)"
Count,8 Tasks,2 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 805.58 MiB 402.79 MiB Shape (720, 606, 484) (360, 606, 484) Count 8 Tasks 2 Chunks Type float32 numpy.ndarray",484  606  720,

Unnamed: 0,Array,Chunk
Bytes,805.58 MiB,402.79 MiB
Shape,"(720, 606, 484)","(360, 606, 484)"
Count,8 Tasks,2 Chunks
Type,float32,numpy.ndarray


In [5]:
# Inputs: 60km GCM fields that have been regridded (the output shape matches the
# 2.2km target grid). Same ensemble member and date range as the target; keep
# only precipitation ("pr") and "psl" (presumably sea-level pressure).
regridded_gcmdata = (
    xr.open_mfdataset('../../../derived_data/60km/rcp85/01/*/day/*.nc')
    .sel(ensemble_member=1, time=slice("1980-12-01", "1982-11-30"))
    .reset_coords()
)[['pr', 'psl']]
regridded_gcmdata

Unnamed: 0,Array,Chunk
Bytes,805.58 MiB,805.58 MiB
Shape,"(720, 606, 484)","(720, 606, 484)"
Count,3 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 805.58 MiB 805.58 MiB Shape (720, 606, 484) (720, 606, 484) Count 3 Tasks 1 Chunks Type float32 numpy.ndarray",484  606  720,

Unnamed: 0,Array,Chunk
Bytes,805.58 MiB,805.58 MiB
Shape,"(720, 606, 484)","(720, 606, 484)"
Count,3 Tasks,1 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,805.58 MiB,805.58 MiB
Shape,"(720, 606, 484)","(720, 606, 484)"
Count,3 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 805.58 MiB 805.58 MiB Shape (720, 606, 484) (720, 606, 484) Count 3 Tasks 1 Chunks Type float32 numpy.ndarray",484  606  720,

Unnamed: 0,Array,Chunk
Bytes,805.58 MiB,805.58 MiB
Shape,"(720, 606, 484)","(720, 606, 484)"
Count,3 Tasks,1 Chunks
Type,float32,numpy.ndarray


In [64]:
# Combine inputs (pr, psl) and target (target_pr) into a single dataset.
merged_data = xr.merge([regridded_gcmdata, cpmdata])
# Select a small spatial subset of the data for trial purposes.
# Per the displayed shapes, grid_longitude has only 484 points, so slice(512)
# keeps all of them: the subset is 512 (lat) x 484 (lon), not square.
merged_data = merged_data.isel({"grid_latitude": slice(512), "grid_longitude": slice(512)})

# Split training/validation/test by contiguous, non-overlapping date ranges.
# The data appear to use a 360-day calendar (720 days over two years in the
# outputs), so every month ends on day 30 and "1982-05-30" is the last day of May.
training_data = merged_data.sel({"time": slice("1980-12-01", "1981-11-30")})
validation_data = merged_data.sel({"time": slice("1981-12-01", "1982-05-30")})
test_data = merged_data.sel({"time": slice("1982-06-01", "1982-11-30")})

# Guard against a typo in the date bounds silently producing an empty split.
assert all(
    split.sizes["time"] > 0 for split in (training_data, validation_data, test_data)
), "one of the training/validation/test splits is empty"

In [65]:
# Inspect the training split (1980-12-01 to 1981-11-30); arrays are still lazy/chunked, as the repr shows.
training_data

Unnamed: 0,Array,Chunk
Bytes,340.31 MiB,340.31 MiB
Shape,"(360, 512, 484)","(360, 512, 484)"
Count,5 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 340.31 MiB 340.31 MiB Shape (360, 512, 484) (360, 512, 484) Count 5 Tasks 1 Chunks Type float32 numpy.ndarray",484  512  360,

Unnamed: 0,Array,Chunk
Bytes,340.31 MiB,340.31 MiB
Shape,"(360, 512, 484)","(360, 512, 484)"
Count,5 Tasks,1 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,340.31 MiB,340.31 MiB
Shape,"(360, 512, 484)","(360, 512, 484)"
Count,5 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 340.31 MiB 340.31 MiB Shape (360, 512, 484) (360, 512, 484) Count 5 Tasks 1 Chunks Type float32 numpy.ndarray",484  512  360,

Unnamed: 0,Array,Chunk
Bytes,340.31 MiB,340.31 MiB
Shape,"(360, 512, 484)","(360, 512, 484)"
Count,5 Tasks,1 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,340.31 MiB,340.31 MiB
Shape,"(360, 512, 484)","(360, 512, 484)"
Count,11 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 340.31 MiB 340.31 MiB Shape (360, 512, 484) (360, 512, 484) Count 11 Tasks 1 Chunks Type float32 numpy.ndarray",484  512  360,

Unnamed: 0,Array,Chunk
Bytes,340.31 MiB,340.31 MiB
Shape,"(360, 512, 484)","(360, 512, 484)"
Count,11 Tasks,1 Chunks
Type,float32,numpy.ndarray


In [66]:
# Inspect the validation split (1981-12-01 to 1982-05-30).
validation_data

Unnamed: 0,Array,Chunk
Bytes,170.16 MiB,170.16 MiB
Shape,"(180, 512, 484)","(180, 512, 484)"
Count,5 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 170.16 MiB 170.16 MiB Shape (180, 512, 484) (180, 512, 484) Count 5 Tasks 1 Chunks Type float32 numpy.ndarray",484  512  180,

Unnamed: 0,Array,Chunk
Bytes,170.16 MiB,170.16 MiB
Shape,"(180, 512, 484)","(180, 512, 484)"
Count,5 Tasks,1 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,170.16 MiB,170.16 MiB
Shape,"(180, 512, 484)","(180, 512, 484)"
Count,5 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 170.16 MiB 170.16 MiB Shape (180, 512, 484) (180, 512, 484) Count 5 Tasks 1 Chunks Type float32 numpy.ndarray",484  512  180,

Unnamed: 0,Array,Chunk
Bytes,170.16 MiB,170.16 MiB
Shape,"(180, 512, 484)","(180, 512, 484)"
Count,5 Tasks,1 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,170.16 MiB,170.16 MiB
Shape,"(180, 512, 484)","(180, 512, 484)"
Count,11 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 170.16 MiB 170.16 MiB Shape (180, 512, 484) (180, 512, 484) Count 11 Tasks 1 Chunks Type float32 numpy.ndarray",484  512  180,

Unnamed: 0,Array,Chunk
Bytes,170.16 MiB,170.16 MiB
Shape,"(180, 512, 484)","(180, 512, 484)"
Count,11 Tasks,1 Chunks
Type,float32,numpy.ndarray


In [67]:
# Inspect the held-out test split (1982-06-01 to 1982-11-30).
test_data

Unnamed: 0,Array,Chunk
Bytes,170.16 MiB,170.16 MiB
Shape,"(180, 512, 484)","(180, 512, 484)"
Count,5 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 170.16 MiB 170.16 MiB Shape (180, 512, 484) (180, 512, 484) Count 5 Tasks 1 Chunks Type float32 numpy.ndarray",484  512  180,

Unnamed: 0,Array,Chunk
Bytes,170.16 MiB,170.16 MiB
Shape,"(180, 512, 484)","(180, 512, 484)"
Count,5 Tasks,1 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,170.16 MiB,170.16 MiB
Shape,"(180, 512, 484)","(180, 512, 484)"
Count,5 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 170.16 MiB 170.16 MiB Shape (180, 512, 484) (180, 512, 484) Count 5 Tasks 1 Chunks Type float32 numpy.ndarray",484  512  180,

Unnamed: 0,Array,Chunk
Bytes,170.16 MiB,170.16 MiB
Shape,"(180, 512, 484)","(180, 512, 484)"
Count,5 Tasks,1 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,170.16 MiB,170.16 MiB
Shape,"(180, 512, 484)","(180, 512, 484)"
Count,11 Tasks,1 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 170.16 MiB 170.16 MiB Shape (180, 512, 484) (180, 512, 484) Count 11 Tasks 1 Chunks Type float32 numpy.ndarray",484  512  180,

Unnamed: 0,Array,Chunk
Bytes,170.16 MiB,170.16 MiB
Shape,"(180, 512, 484)","(180, 512, 484)"
Count,11 Tasks,1 Chunks
Type,float32,numpy.ndarray


In [103]:
from unet import UNet

In [104]:
# The first UNet argument is the input-channel count (the repr's first layer is
# Conv2d(2, 64, ...)): here psl and pr, stacked as channels in the training_tensor cell.
# The second argument is presumably the output-channel count (one target field,
# target_pr) — confirm against unet.UNet's signature.
n_input_channels = 2
n_output_channels = 1
model = UNet(n_input_channels, n_output_channels)
model

UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(2, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(128, eps=1e-05, moment

In [82]:
import torch

In [101]:
# Build a tiny smoke-test batch from the first few training days.
img_size = 16   # grid points per side of the spatial crop
n_samples = 2   # number of time steps in the batch

# Select the crop once so both input channels are guaranteed to cover exactly
# the same days and grid points (previously the selection was duplicated per channel).
training_subset = training_data.isel(
    {"time": slice(n_samples), "grid_latitude": slice(img_size), "grid_longitude": slice(img_size)}
)
# Stack variables along dim=1 as channels; the order (psl, pr) must stay consistent
# with however the model is trained and evaluated.
# NOTE(review): inputs are fed un-normalised; psl and pr presumably differ by
# orders of magnitude — consider standardising each channel before training.
training_tensor = torch.stack(
    [torch.from_numpy(training_subset[var].values) for var in ("psl", "pr")],
    dim=1,
)
assert training_tensor.shape == (n_samples, 2, img_size, img_size), training_tensor.shape

In [102]:
# Smoke-test the forward pass on the tiny batch.
# no_grad() stops autograd from recording a graph. The cell's result is kept in
# the notebook's Out history, so an attached graph would keep every intermediate
# activation alive for the rest of the session.
# NOTE: nothing here calls model.eval(), and nn.Module instances start in training
# mode. BatchNorm layers (track_running_stats=True in the repr) therefore normalise
# with batch statistics and update their running mean/var every time this cell runs.
with torch.no_grad():
    smoke_test_output = model(training_tensor)
smoke_test_output

tensor([[[[ 0.6202,  0.5706,  0.2214,  0.1205,  0.1891,  0.2663,  0.4527,
            0.5021,  0.4242,  0.3176,  0.3444,  0.2189,  0.1755,  0.6847,
            0.4419,  0.0890],
          [ 0.2414,  0.3094, -0.4780, -0.1083, -0.1397, -0.0291,  0.1862,
            0.2579,  0.1034, -0.1111, -0.2812, -0.4519, -0.0406,  0.3725,
            0.2070, -0.5428],
          [ 0.0156,  0.2379,  0.1201, -0.8489, -0.4812, -0.2933, -0.3216,
           -0.2264, -0.3710, -0.5510, -0.9947, -0.9932, -0.9958, -0.7180,
           -0.3610, -0.9370],
          [-0.1934,  0.0385,  0.2049, -0.2464, -0.4022, -0.0401, -0.1055,
           -0.2441, -0.2387, -0.3251, -0.6251, -0.4471, -0.1314,  0.2599,
            0.4373, -0.7219],
          [ 0.0938,  0.3708, -0.3274, -0.3042, -0.2572, -0.0986,  0.0157,
           -0.1350, -0.2778, -0.5033, -0.5006, -0.2677,  0.1657,  0.6060,
            0.4405, -0.6405],
          [ 0.1727,  0.3765, -0.2777,  0.0020,  0.0111, -0.0769, -0.1615,
           -0.1866, -0.2810, -0.4538