In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pyproj

pd.set_option("display.max_columns", 500)
gpu = torch.device("cuda:0")

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Open the Copernicus NetCDF file and display its summary as the cell output.
nrt_dataset_path = (
    "/home/knowit/Home_Foresee/forseeModel/data/copernicus/datasets/norway_nrt.nc"
)
ds = xr.open_dataset(nrt_dataset_path)
ds

In [None]:
# Display the keys of the dataset opened above (last expression is the cell output).
ds.keys()

In [None]:
from training.dataloader import ForSeaDataset


# Flag forwarded to both ForSeaDataset (here) and ForseaAutoEncoder (further down),
# so it is defined once and shared.
# NOTE(review): the notebook used `log_target` without ever defining it, which
# raises NameError on a fresh kernel (Restart & Run All). The value below is an
# assumption — confirm which setting was used for the original runs.
log_target = True

# Alternative ocean input that was tried earlier, kept for reference:
# ocean_data_path = '/home/knowit/Home_Foresee/forseeModel/data/copernicus/nrt/resampled/all_vars.nc'
ocean_data_path = (
    "/home/knowit/Home_Foresee/forseeModel/data/copernicus/datasets/ocean_data.nc"
)
route_data_path = (
    "/home/knowit/Home_Foresee/forseeModel/data/VMS_DCA_joined/cod_trawl.parquet"
)

# Build the dataset from the ocean file and the route file. With batched=True the
# training loop below iterates the dataset directly, so no DataLoader is used.
dataset = ForSeaDataset(
    ocean_data_path, route_data_path, log_target=log_target, batched=True
)
# dataloader = DataLoader(dataset, batch_size=3, shuffle=True)

In [None]:
# One 20-bin histogram per column of the route input matrix, laid out in a single row.
num_features = dataset.route_input.shape[1]
fig, axes = plt.subplots(1, num_features, figsize=(16, 4), squeeze=False)
for i, ax in enumerate(axes[0]):
    ax.hist(dataset.route_input[:, i], bins=20)
plt.show()

In [None]:
# Pull the first item: a pair of inputs (X1, X2) together with its target y.
sample_inputs, y = dataset[0]
X1, X2 = sample_inputs

In [None]:
# Display the shapes of both inputs and the target of the first sample.
tuple(part.shape for part in (X1, X2, y))

In [None]:
import torch
from models.EncoderDecoder import ForseaAutoEncoder

# Size the network from the first dataset sample.
ocean_data_shape = X1.shape
route_data_features = X2.shape[1]

# The positional hyperparameters ([16, 16, 16], (3, 3), 128, 1) are interpreted by
# ForseaAutoEncoder's constructor, which is defined outside this notebook.
# `log_target` must be the same flag that was given to the dataset.
model = ForseaAutoEncoder(
    ocean_data_shape, [16, 16, 16], (3, 3), route_data_features, 128, 1,
    log_target=log_target,
)
# Move the parameters to the current CUDA device.
model = model.cuda()

In [None]:
import torch.optim as optim

# Mean-squared-error objective, optimised with Adam over all model parameters.
learning_rate = 1e-3
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Training loop: iterate the dataset for a fixed number of epochs and record the mean
# loss of every `print_period` optimised batches in `loss_history`.
num_epochs = 5
print_period = 200  # number of optimised batches averaged into each reported loss

model.train()
loss_history = []
for epoch in range(num_epochs):  # loop over the dataset multiple times
    # The running window restarts every epoch; an incomplete window at the end of
    # an epoch is dropped rather than reported.
    running_loss = 0.0
    batches_in_window = 0
    for i, data in enumerate(dataset):
        # Each item is ((ocean_input, route_input), target).
        (ocean_input, route_input), roundweight = data
        if len(route_input) == 0:
            # Nothing to train on for this item. It must not count towards the
            # reporting window, otherwise the reported mean is wrong.
            continue

        # Reset gradients accumulated by the previous step.
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(ocean_input, route_input)
        loss = criterion(outputs, roundweight)
        loss.backward()
        optimizer.step()

        # Accumulate statistics over batches that were actually optimised, so the
        # divisor always matches the number of summed losses.
        running_loss += loss.item()
        batches_in_window += 1
        if batches_in_window == print_period:
            mean_loss = running_loss / batches_in_window
            print(
                f"[{epoch + 1}, {i + 1:5d}] | output: {torch.mean(outputs)} | loss: {mean_loss:.3f}"
            )
            loss_history.append(mean_loss)
            running_loss = 0.0
            batches_in_window = 0

print("Finished Training")

In [None]:
# Plot the recorded windowed training losses in order of collection.
history = np.asarray(loss_history)
fig, ax = plt.subplots()
ax.plot(history)
plt.show()

In [None]:
# Build every (x, y) combination of the ocean grid coordinates and flatten the result
# to two 1-D coordinate vectors.
# NOTE: `y` is rebound here; in an earlier cell it held the target of the first sample.
x = dataset.ocean_data["x"].values
y = dataset.ocean_data["y"].values
X, Y = np.meshgrid(x, y)
X_flat, Y_flat = X.reshape(-1), Y.reshape(-1)

In [None]:
# Timestamp and scalar time features used for the inference grid.
inference_date = pd.Timestamp("2022-09-01 12:00:00")
# NOTE(review): fixed at 0.3 although the timestamp above is 12:00 — confirm this
# matches how the training data encodes time of day.
time_of_day = 0.3
# Day of year divided by 365 (no special handling of leap years).
day_of_year = inference_date.dayofyear / 365

In [None]:
# One row per grid point. Columns 0-3 carry time of day, day of year, x and y;
# any remaining route feature columns stay at zero.
# NOTE(review): x/y are written as raw coordinates — confirm they are on the same
# scale as the route features the model was trained on.
n_points = len(X_flat)
inference_array = np.zeros((n_points, route_data_features))
for col, values in enumerate((time_of_day, day_of_year, X_flat, Y_flat)):
    inference_array[:, col] = values
inference_tensor = torch.from_numpy(inference_array).to(torch.float32).cuda()

In [None]:
# Fill one slot along axis 1 per ocean variable with that variable's field at the
# inference time, min-max scaled with the dataset's stored per-variable bounds.
ocean_array = np.zeros(ocean_data_shape)
for channel, var_name in enumerate(dataset.ocean_data.keys()):
    # Write the raw field first, then normalise the stored copy in place.
    ocean_array[:, channel] = (
        dataset.ocean_data[var_name].sel(time=inference_date).values
    )
    lower = dataset.ocean_min[var_name]
    value_range = dataset.ocean_max[var_name] - lower
    ocean_array[:, channel] = (ocean_array[:, channel] - lower) / value_range

# Replace NaNs with numbers, cast to float32 and move to the GPU.
ocean_tensor = torch.nan_to_num(torch.from_numpy(ocean_array)).float().cuda()

In [None]:
# Switch the model to evaluation mode and run a single forward pass over the grid.
model.eval()
# Gradients are not needed for inference; disabling autograd avoids building and
# retaining a graph for the forward pass.
with torch.no_grad():
    pred = model(ocean_tensor, inference_tensor)

In [None]:
# Bring the predictions back to host memory so they can be converted to NumPy.
pred = pred.to("cpu")

In [None]:
# Reshape the flat predictions onto the meshgrid layout and render them as an image.
pred_grid = pred.reshape(X.shape).detach().numpy()
plt.imshow(pred_grid)
plt.show()