In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from torchsummary import summary
from poutyne import Model, SKLearnMetrics
from sklearn.metrics import r2_score, explained_variance_score, mean_squared_error, median_absolute_error
from utils import metric_flatten, get_poutyne_callbacks, saferm



In [None]:
# Run on the requested GPU when one is available, otherwise fall back to the CPU.
cuda_device = 0
device = torch.device("cuda:%d" % cuda_device if torch.cuda.is_available() else "cpu")

# Seed every RNG the notebook draws from so "Restart & Run All" is reproducible:
# NumPy drives the dataset sampling (np.random.rand), torch drives the weight
# initialisation of nn.Linear (and, presumably, poutyne's batch shuffling).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Experiment identifier, passed to the project helper `get_poutyne_callbacks`, which
# returns the summary/checkpoint directories for this run (presumably derived from this name).
experiment_name: str = "test"

In [None]:
def func(xx, yy):
    """Ground-truth surface for the 2-D regression toy problem.

    Evaluates ``2 * sin(20 * r**2)``, where ``r`` is the distance from the point
    (0.7, 0.3). Works element-wise on scalars or NumPy arrays of matching shape,
    so the output follows the inputs' shape and float dtype.
    """
    # Alternative (harder) target that was tried before:
    # ((xx-0.5)**2 + (yy-0.25)-0.2 + (xx-0.7)**3+ yy**4 - 2*np.sin(20*((xx-0.7)**2 + (yy-0.3)**2)))/2
    squared_radius = (xx - 0.7) ** 2 + (yy - 0.3) ** 2
    return 2 * np.sin(20 * squared_radius)


In [None]:
# Visualise the ground-truth target on a regular grid (step 0.02) over [0, 1) x [0, 1).
x = np.arange(0, 1, 0.02)
y = np.arange(0, 1, 0.02)
xx, yy = np.meshgrid(x, y, sparse=False)
z = func(xx, yy)
h = plt.contourf(x, y, z)
plt.colorbar(h)
plt.axis('scaled')
plt.title("Target function")
plt.xlabel("x")
# Trailing ';' suppresses the text repr of the last call so only the figure is shown.
plt.ylabel("y");


In [None]:
## Dataset creation
def sample_generator_regression(n):
    """Sample ``n`` random (input, target) pairs for the regression toy problem.

    Inputs are drawn uniformly from [0, 1)^2 with the global NumPy RNG and cast
    to float32. Targets are ``func`` evaluated at those points.

    Returns:
        (inputs, targets) as torch tensors of shapes (n, 2) and (n, 1).
    """
    inputs = np.random.rand(n, 2).astype(np.float32)
    # Column vector of targets, matching a network with output_size=1.
    targets = func(inputs[:, 0], inputs[:, 1])[:, np.newaxis]
    return torch.tensor(inputs), torch.tensor(targets)

# Independent train / validation / test splits, drawn in this order from the global NumPy RNG.
train_data, valid_data, test_data = (
    sample_generator_regression(n_samples) for n_samples in (1024, 64, 128)
)


In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch

class LinearNet(nn.Module):
    """Fully connected regressor: input -> 3 hidden layers of ``n_hidden`` units -> output.

    Hidden layers use leaky-ReLU. The output layer is left linear so the network
    can produce the full real-valued range of a regression target. The toy target
    ``func`` spans [-2, 2].

    Args:
        input_size: number of input features.
        output_size: number of regression outputs.
        n_hidden: width of each hidden layer.
    """

    def __init__(self, input_size=2, output_size=1, n_hidden=64):
        super().__init__()
        self.linear1 = nn.Linear(input_size, n_hidden)
        self.linear2 = nn.Linear(n_hidden, n_hidden)
        self.linear3 = nn.Linear(n_hidden, n_hidden)
        self.linear4 = nn.Linear(n_hidden, output_size)

    def forward(self, x):
        """Map a ``(..., input_size)`` tensor to ``(..., output_size)`` predictions."""
        hidden = F.leaky_relu(self.linear1(x))
        hidden = F.leaky_relu(self.linear2(hidden))
        hidden = F.leaky_relu(self.linear3(hidden))
        # No activation on the output: a leaky-ReLU here would scale every negative
        # prediction by 0.01, making the negative half of the target nearly unreachable.
        return self.linear4(hidden)


network = LinearNet(input_size=2, output_size=1, n_hidden=64)

In [None]:
# torchsummary builds its dummy input on CUDA by default when a GPU is present, but
# `network` has not been moved to `device` yet (the Model(...) cell does that via
# `device=`). Summarise on the network's own device to avoid a CPU-weights /
# CUDA-input mismatch.
summary(network, (2,), device=next(network.parameters()).device.type)

In [None]:
# Regression training setup: Adam optimiser, MSE loss, L1 tracked per batch, and four
# sklearn regression scores computed once per epoch. Each score is wrapped by the
# project helper `metric_flatten`, presumably to flatten y_true/y_pred before scoring.
model = Model(
    network,
    'adam',
    'mse',
    batch_metrics=["l1"],
    epoch_metrics=[
        SKLearnMetrics(metric_flatten(score_fn))
        for score_fn in (
            r2_score,
            explained_variance_score,
            mean_squared_error,
            median_absolute_error,
        )
    ],
    device=device,
)

In [None]:
# Project helper: returns the poutyne callbacks for this experiment together with the
# summary and checkpoint directories it uses (see utils.get_poutyne_callbacks).
(callbacks, summary_dir, checkpoint_dir) = get_poutyne_callbacks(experiment_name)

In [None]:
from torch.utils.data import TensorDataset
# Wrap each (inputs, targets) tensor pair as a Dataset for poutyne's fit_dataset.
train_dataset, valid_dataset, test_dataset = (
    TensorDataset(*split) for split in (train_data, valid_data, test_data)
)

In [None]:
# Optimization parameters forwarded to Model.fit_dataset.
optimization_kwargs = {"batch_size": 8, "epochs": 10}

# Optional clean slate: uncomment to remove (via the project helper `saferm`) and
# recreate the summary/checkpoint folders before training.
# saferm(summary_dir)
# saferm(checkpoint_dir)
# summary_dir.mkdir(parents=True, exist_ok=True)
# checkpoint_dir.mkdir(parents=True, exist_ok=True)

# Train the model; `history` holds one log dict per epoch (used for the learning curves).
history = model.fit_dataset(
    train_dataset,
    valid_dataset=valid_dataset,
    callbacks=callbacks,
    **optimization_kwargs,
)

In [None]:
# Learning curves: training vs. validation loss (MSE, the loss passed to Model) per epoch.
e = [v['epoch'] for v in history]
val_loss = [v['val_loss'] for v in history]
train_loss = [v['loss'] for v in history]
plt.plot(e, train_loss, label="Training MSE")
plt.plot(e, val_loss, label="Validation MSE")
plt.xlabel("Epochs")
plt.ylabel("MSE")
plt.title("Learning curves")
# Trailing ';' suppresses the Legend repr so only the figure is displayed.
plt.legend();

In [None]:
# Evaluate the target on the plotting grid (step 0.02) and query the trained network at
# every grid node, then reshape its flat predictions back onto the grid.
x = np.arange(0, 1, 0.02)
y = np.arange(0, 1, 0.02)
xx, yy = np.meshgrid(x, y, sparse=False)
Z1 = func(xx, yy)

# One (x, y) row per grid node, in row-major order so reshape(xx.shape) inverts it.
grid_points = np.column_stack((xx.ravel(), yy.ravel())).astype(np.float32)
X = torch.tensor(grid_points)
Z2 = model.predict(X).reshape(xx.shape)

In [None]:
# Side-by-side comparison of ground truth and network prediction on one colour scale.
vmin = np.min(Z1)
vmax = np.max(Z1)
# contourf picks its contour levels from each array's own data range; vmin/vmax only
# change the colour normalisation, so the two panels (and colorbars) would not share
# a scale. Use explicit shared levels from the ground-truth range instead, and
# extend="both" so predictions outside that range are still coloured, not left blank.
levels = np.linspace(vmin, vmax, 21)

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, Z, title in zip(axes, (Z1, Z2), ("Ground truth", "Neural network")):
    h = ax.contourf(x, y, Z, levels=levels, extend="both")
    fig.colorbar(h, ax=ax)
    ax.axis('scaled')
    ax.set(title=title, xlabel="x", ylabel="y")