In [4]:
import sys
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

# Local
# Resolve the project layout relative to the notebook's working directory:
# `src` is the parent of the working directory and `root` is the parent of `src`.
cwd = pathlib.Path().resolve()
src = cwd.parent
root = src.parent

# Make the modules that live under `src` importable. The membership check keeps
# sys.path free of duplicate entries when this cell is executed more than once.
if str(src) not in sys.path:
    sys.path.append(str(src))

from utils.utils import count_parameters, create_sequence
from utils.train import train_and_validate
from utils.watertopo import WaterTopo
from unet import UNet

In [5]:
# Pick the compute device: CUDA when torch reports it as available, CPU otherwise.
# (This cell only ever selects 'cuda' or 'cpu'; it does not select the 'mps' backend.)
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Report what was detected and which device will be used
print("Is CUDA enabled?", torch.cuda.is_available())
print("Number of GPUs", torch.cuda.device_count())
print('Using device:', device)

# Extra details for GPU 0: its name and current memory figures in GB (bytes / 1024**3)
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    for label, query in (('Allocated:', torch.cuda.memory_allocated),
                         ('Cached:   ', torch.cuda.memory_reserved)):
        print(label, round(query(0)/1024**3, 1), 'GB')

Is CUDA enabled? False
Number of GPUs 0
Using device: cpu


In [7]:
# Instantiate the U-Net.
# NOTE(review): the positional arguments (2, 1, False) are presumably the input
# channels, output channels and an architecture flag — confirm against UNet.__init__.
model = UNet(2, 1, False)

# Report the size of the network as computed by count_parameters
num_trainable = count_parameters(model)
print(f"U-Net --> num. trainable parameters:{num_trainable:8d}")

U-Net --> num. trainable parameters:31037057


In [5]:
# --- Reproducibility ---------------------------------------------------------
# Seed every RNG used in this notebook so that DataLoader shuffling and the
# random sample picks in the plotting cell are repeatable between runs.
# Weights of a model that was instantiated before this cell ran are not affected.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# --- Data --------------------------------------------------------------------
sim_amount = 80              # passed to WaterTopo.load_simulations as sim_amount
use_augmented_data = False   # passed to WaterTopo.load_simulations
T = 1                        # second argument of create_sequence
H = 1                        # third argument of create_sequence
training_size = 0.8          # fraction of samples used for training; the rest is split 50/50 into val/test

# --- Optimisation ------------------------------------------------------------
batch_size = 10
num_epochs = 200
lr = 0.0005
criterion = nn.MSELoss()
optimizer = optim.AdamW      # optimizer *class*; it is instantiated in the training cell

# --- Output ------------------------------------------------------------------
model_name = "unet_orig_data"  # file name of the checkpoint under ../results/trained_models/

In [6]:
# Load the normalised simulations from <root>/data/normalized_data/tra_val
sims = WaterTopo.load_simulations(str(root)+"/data/normalized_data/tra_val", sim_amount=sim_amount, number_grids=64, use_augmented_data=use_augmented_data)

# Build the (input, target) samples and squeeze axis 1 of both arrays.
# NOTE(review): presumably T is the number of input steps and H the number of
# target steps per sample — confirm in create_sequence.
X, Y = create_sequence(sims, T, H)
X, Y = X.squeeze(1), Y.squeeze(1)

# First split: training set vs. a held-out remainder.
# The sample indexes are carried through the split so each subset can be traced back to X/Y.
X_tra, X_tst, Y_tra, Y_tst, ix_tra, ix_tst = train_test_split(
    X, Y, np.arange(X.shape[0]), test_size=1-training_size, shuffle=True, random_state=42)

# Second split: divide the held-out remainder into validation and test sets (50/50)
X_val, X_tst, Y_val, Y_tst, ix_val, ix_tst = train_test_split(
    X_tst, Y_tst, ix_tst, test_size=0.5, shuffle=True, random_state=42)

# Create datasets and data loaders (only the training loader is shuffled)
train_dataset = TensorDataset(torch.tensor(X_tra, dtype=torch.float32), torch.tensor(Y_tra, dtype=torch.float32))
val_dataset = TensorDataset(torch.tensor(X_val, dtype=torch.float32), torch.tensor(Y_val, dtype=torch.float32))

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Instantiate the optimizer.
# The config cell binds `optimizer` to an optimizer class and this cell rebinds it
# to an instance. If the cell is executed again, `optimizer` is already an instance,
# so recover its class instead of trying to call the instance (which raises TypeError).
optimizer_factory = type(optimizer) if isinstance(optimizer, optim.Optimizer) else optimizer
optimizer = optimizer_factory(model.parameters(), lr=lr)

# Checkpoint location handed to train_and_validate and read back below
save_path = "../results/trained_models/" + model_name

# Training
train_losses, val_losses, best_val_loss, time = train_and_validate(model, train_loader, val_loader, criterion, optimizer, num_epochs, device, save_path)

# Load the best model. map_location remaps the stored tensors onto the active
# device, so a checkpoint written on a GPU machine still loads on a CPU-only one.
model.load_state_dict(torch.load(save_path, map_location=device))

KeyboardInterrupt: 

In [None]:
# Plot the two loss histories returned by train_and_validate on one explicit Axes
fig, ax = plt.subplots()
for history, name in ((train_losses, 'Training'), (val_losses, 'Validation')):
    ax.plot(history, label=name)
# ax.set_yscale('log')
ax.set_title('Losses')
ax.set_xlabel('Epochs')
ax.legend()
plt.show()

In [None]:
model.eval()

grid_size = 64
# NOTE(review): `channels` is not used in this cell; presumably the number of
# input channels (cf. UNet(2, ...)) — confirm or remove.
channels = 2

# 2x4 grid of axes, flattened so that consecutive pairs hold (target, prediction)
f,axs = plt.subplots(2, 4, figsize=(20,10))
axs = axs.reshape(-1)

# Run inference on whichever device the model's weights currently live on;
# feeding CPU inputs to a CUDA model would raise a device-mismatch error.
model_device = next(model.parameters()).device

with torch.no_grad():
    for i in range(0, len(axs), 2):
        # Randomly select a sample from the full set X (this includes training samples)
        idx = random.randint(0, X.shape[0]-1)
        inputs = X[idx]
        targets = Y[idx]

        # Predict: add a batch dimension, move to the model's device, and bring the
        # result back to the CPU so matplotlib can convert it to an array
        inputs = torch.as_tensor(inputs, dtype=torch.float32).unsqueeze(0).to(model_device)
        prediction = model(inputs).cpu()

        # Plot the target next to the model's prediction
        axs[i].imshow(targets.reshape([grid_size, grid_size, 1]))
        axs[i+1].imshow(prediction.reshape([grid_size, grid_size, 1]))

        axs[i].set_title(f"Target: simulation {idx}")
        axs[i+1].set_title(f"Prediction: simulation {idx}")

f.tight_layout()

In [None]:
from utils.utils import mse_per_timestep, recursive_pred

# Call recursive_pred with the model, the first sample of X and the number of samples in X,
# then score the result against Y with mse_per_timestep.
# NOTE(review): presumably the third argument is the number of rollout steps. X is built
# from all loaded simulations, so X.shape[0] may span more than one simulation — confirm
# that comparing the rollout against the whole of Y is intended.
outputs = recursive_pred(model, X[0],  X.shape[0])
mse = mse_per_timestep(Y, outputs)

# Plot the error curve on an explicit Axes
fig, ax = plt.subplots()
ax.plot(mse, label='MSE')
ax.set_title('Mean squared error over time')
ax.set_xlabel('Timestep')
ax.legend()
plt.show()