In [1]:
import sys
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import random
import torch

# %matplotlib ipympl
# %matplotlib inline
%matplotlib widget

In [2]:
# Local
cwd = pathlib.Path().resolve()
src = cwd.parent
root = src.parent
sys.path.append(str(src))
sys.path.append(str(root))

from utils.watertopo import WaterTopo
from utils.simulation import Simulation
from utils.utils import count_parameters, mse_per_timestep, recursive_pred
from utils.plot import compare_simulations_slider

from models.unet_mask import UNet_mask

In [3]:
#initialize GPU -  In case of windows use cuda instead of nps
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print("Is CUDA enabled?",torch.cuda.is_available())
print("Number of GPUs",torch.cuda.device_count())

Is CUDA enabled? True
Number of GPUs 1


In [4]:
# Initialize the model

best_model = UNet_mask(2, [32, 64], 1, 10)


In [5]:
# Load the parameters from training
load_path = "../results/trained_models/unet_mask/"

best_model.load_state_dict(torch.load(load_path + "unet_32_64_orig_data80_skip5_hardmask5", map_location="cpu"))
best_model.eval()

UNet_mask(
  (encoder): Encoder(
    (in_layer): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(2, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (enc_blocks): ModuleList(
      (0): Down(
        (maxpool_conv): Sequential(
          (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
          (1): DoubleConv(
            (double_conv): Sequential(
              (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(in

In [6]:
grid_size = 64
skip = 5

label = ["target", "5-skip hard mask 5 pixels"]

sim = WaterTopo.load_simulations(str(root)+"/data/normalized_data/test1", 
                                 1, grid_size, 
                                 use_augmented_data=True)[0]

topo = sim.topography.reshape([1, grid_size, grid_size])
wd_0 = sim.wd[0].reshape([1, grid_size, grid_size])
X = np.concatenate([topo, wd_0])

wds = []
wds.append(sim.wd)

with torch.no_grad():
    
    best_model.train(False)

    if skip:
        sim_skips = sim.implement_skips(skip)

    outputs = recursive_pred(best_model, X, sim_skips.wd.shape[0]-1, include_first_timestep=True)
    wds.append(outputs.detach().numpy())
            
    slider5 = compare_simulations_slider(wds, label)

: 

In [None]:
# # Let's try the model recursively!
# grid_size = 64
# channels = 2
# sim_length = 97

# # Load in a random simulation
# sim = WaterTopo.load_simulations(str(root)+"/data/normalized_data/test1/", 1, grid_size, use_augmented_data=True)[0]

# # Select the time step where you want to start
# id = 0

# # Get de inputs and targets
# inputs = np.zeros((1, channels, grid_size, grid_size))
# inputs[0, 0, :, :] = sim.topography
# inputs[0, 1, :, :] = sim.return_timestep(id)
# inputs = torch.tensor(inputs, dtype=torch.float32).to(device)
# inputs.cpu()

# targets = sim.wd
# targets = torch.tensor(targets, dtype=torch.float32).to(device)
# targets[0,:,:] = inputs[0,1,:,:]

# # initialize the outputs
# outputs = torch.zeros(targets.shape).to(device)
# outputs[0,:,:] = inputs[0,1,:,:]

# # run the model
# for t in range(1, sim_length):
#    outputs[t,:,:] = best_model(inputs)
#    inputs[0,1,:,:] = outputs[t,:,:]

# mse = mse_per_timestep(targets, outputs)

In [None]:
# # Let's plot the MSE
# fig, ax = plt.subplots()
# ax.plot(np.arange(0, len(mse)), mse, label="MSE")
# ax.set_xlabel("time steps")
# ax.set_ylabel("MSE")
# ax.set_title("... - MSE per timestep")

In [None]:
# # Anim
# ani1 = WaterTopo(topo, outputs.cpu().detach().numpy()).plot_animation()
# ani2 = WaterTopo(topo, targets.cpu().detach().numpy()).plot_animation()