# Temperature & salinity from a dense feed-forward network

In [None]:
import pathlib
from typing import Callable, Optional

import pandas as pd
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import scipy.stats
import torch

%load_ext lab_black
%load_ext tensorboard

# Location of the input table (resolved relative to the working directory).
PATH_TO_CSV = pathlib.Path("dataset.csv")

# Names of the CSV columns that hold the band values fed to the network.
BANDS = "B1 B2 B3 B4 B5 B6 B7 B8 B8A B9 B10 B11 B12".split()

# Sanity check: (rows, columns) of the CSV, with its first column used as the index.
pd.read_csv(PATH_TO_CSV, index_col=0).shape

## Data handling

In [None]:
class DataModule(pl.LightningDataModule):
    """Serve band columns and one target column of ``PATH_TO_CSV`` as tensors.

    The CSV is read in :meth:`setup`, ``preprocessing`` is applied to the band
    columns and to the target column separately, and the rows are split at
    random (reproducibly, via ``seed``) into train / validation / test subsets.

    Args:
        target: Name of the CSV column to predict.
        bands: Names of the CSV columns used as inputs; must not contain
            ``target``.
        train_batch_size: Batch size of the training dataloader.
        preprocessing: Shape-preserving transformation applied to the data
            before it is converted to tensors. It is called with a DataFrame
            (bands) and with a single column (target).
        train_frac: Fraction of rows assigned to the training subset; the
            remainder is divided evenly between validation and test.
        seed: Seed of the generator that drives the random split.

    Raises:
        ValueError: If ``target`` is listed in ``bands``, or if
            ``preprocessing`` changes the shape of a probe frame.
    """

    def __init__(
        self,
        target: str,
        bands: list[str],
        train_batch_size: int,
        preprocessing: Callable[[pd.DataFrame], pd.DataFrame],
        train_frac: float = 0.7,
        seed: int = 123456788,
    ):
        super().__init__()
        self.target = target
        self.bands = bands
        self.train_batch_size = train_batch_size
        self.preprocessing = preprocessing
        self.train_frac = train_frac
        self.seed = seed

        # Explicit raises instead of ``assert`` so the checks survive ``python -O``.
        if target in bands:
            raise ValueError(f"target {target} should not be present in bands")

        # Smoke-test the callable on a small random frame: it must keep the shape.
        probe = pd.DataFrame(torch.rand(10, 4).numpy())
        if self.preprocessing(probe).shape != probe.shape:
            raise ValueError("preprocessing must preserve the shape of its input")

    def setup(self, stage: Optional[str] = None):
        """Read the data and perform the train/val/test split.

        Args:
            stage: Stage identifier supplied by the caller; it is not used, so
                every stage produces the same split.
        """
        data = pd.read_csv(PATH_TO_CSV, index_col=0)

        # NOTE(review): preprocessing runs on the full table before the split,
        # so any statistics it computes also cover validation and test rows.
        bands_data = self.preprocessing(data[self.bands])
        target_data = self.preprocessing(data[self.target])
        full_dataset = torch.utils.data.TensorDataset(
            torch.from_numpy(bands_data.values),
            torch.from_numpy(target_data.values).view(-1, 1),  # column vector
        )

        n_full = len(full_dataset)
        n_train = int(n_full * self.train_frac)
        n_val = (n_full - n_train) // 2  # equal sized validation & test
        n_test = n_full - n_train - n_val  # test absorbs any odd remainder

        (
            self.train_dataset,
            self.val_dataset,
            self.test_dataset,
        ) = torch.utils.data.random_split(
            full_dataset,
            (n_train, n_val, n_test),
            # A dedicated, seeded generator makes the split reproducible.
            generator=torch.Generator().manual_seed(self.seed),
        )

    def train_dataloader(self):
        """Return the training loader, batched by ``train_batch_size``."""
        # NOTE(review): no ``shuffle=True``, so batches arrive in the same order
        # every epoch - confirm this is intended.
        return torch.utils.data.DataLoader(
            self.train_dataset, batch_size=self.train_batch_size
        )

    def val_dataloader(self):
        """Return a loader that yields the whole validation subset as one batch."""
        return torch.utils.data.DataLoader(
            self.val_dataset, batch_size=len(self.val_dataset)
        )

    def test_dataloader(self):
        """Return a loader that yields the whole test subset as one batch."""
        return torch.utils.data.DataLoader(
            self.test_dataset, batch_size=len(self.test_dataset)
        )

## Simple dense network

In [None]:
class DenseNet(pl.LightningModule):
    """Fully connected network trained with a mean-squared-error loss.

    The network is a stack of ``Linear`` layers of widths
    ``in_dims -> hidden_dims... -> out_dims``. Each hidden layer is followed by
    an instance of ``activation``; the last layer is followed by an instance
    of ``final_activation``.

    Args:
        in_dims: Number of input features.
        hidden_dims: Width of each hidden layer, in order.
        activation: Module class instantiated after every hidden layer.
        final_activation: Module class instantiated after the output layer.
        out_dims: Number of outputs.
    """

    def __init__(
        self,
        in_dims: int,
        hidden_dims: list[int],
        activation: type[torch.nn.Module],
        final_activation: type[torch.nn.Module] = torch.nn.Identity,
        out_dims: int = 1,
    ):
        super().__init__()

        widths = [in_dims, *hidden_dims, out_dims]
        act_classes = [activation] * len(hidden_dims) + [final_activation]

        # Pair consecutive widths to get the (in, out) size of every layer.
        modules = []
        for (n_in, n_out), act_cls in zip(zip(widths, widths[1:]), act_classes):
            modules += [torch.nn.Linear(n_in, n_out), act_cls()]

        self.network = torch.nn.Sequential(*modules)
        self.loss_func = torch.nn.functional.mse_loss

    def forward(self, x: torch.Tensor):
        """Flatten each sample, cast to float32 and run it through the network."""
        flattened = x.view(x.shape[0], -1)
        return self.network(flattened.float())

    def configure_optimizers(self):
        """Return Adam (lr 1e-3) plus a plateau scheduler driven by ``val_loss``."""
        adam = torch.optim.Adam(self.parameters(), lr=1e-3)
        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(adam, verbose=True)
        return {
            "optimizer": adam,
            "lr_scheduler": {
                "scheduler": plateau,
                "monitor": "val_loss",
            },
        }

    def _batch_loss(self, batch):
        """Compute the loss between predictions and float32 targets of ``batch``."""
        inputs, targets = batch
        predictions = self.forward(inputs)
        return self.loss_func(predictions, targets.float())

    def training_step(self, batch, batch_idx):
        """Return the training loss of ``batch``."""
        return self._batch_loss(batch)

    def validation_step(self, batch, batch_idx):
        """Log the loss of ``batch`` as ``val_loss`` and return it."""
        loss = self._batch_loss(batch)
        self.log("val_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        """Return the test loss of ``batch``."""
        return self._batch_loss(batch)

## Plotting

In [None]:
def sample(dm: DataModule, network: DenseNet, dataset: str):
    """Return targets and predictions for the first batch of one dataloader.

    Args:
        dm: Data module providing a ``<dataset>_dataloader`` method.
        network: Model used to compute the predictions.
        dataset: Prefix of the dataloader method to use, e.g. ``"train"``.

    Returns:
        A ``(targets, predictions)`` pair of flat NumPy arrays.
    """
    make_loader = getattr(dm, f"{dataset}_dataloader")
    features, targets = next(iter(make_loader()))
    flat_targets = targets.view(-1).detach().numpy()
    flat_predictions = network(features).view(-1).detach().numpy()
    return flat_targets, flat_predictions


def plot(dm: DataModule, network: DenseNet):
    """Scatter predictions against targets for the train, val and test subsets.

    One panel is drawn per subset, each with the least-squares regression line
    of prediction on target and its correlation coefficient in the legend.

    Args:
        dm: Data module whose subsets are plotted; ``setup`` must already have
            run, since ``dm.train_dataset`` is read here.
        network: Model used to compute the predictions.

    Returns:
        The matplotlib figure holding the three panels.
    """
    subsets = ("train", "val", "test")

    # `sample` only reads the first batch, so temporarily serve the whole
    # training set as one batch - and restore the caller's batch size after,
    # even if sampling fails.
    original_batch_size = dm.train_batch_size
    dm.train_batch_size = len(dm.train_dataset)
    try:
        samples = {name: sample(dm, network, name) for name in subsets}
    finally:
        dm.train_batch_size = original_batch_size

    # The panels share their axes, so limits set on one panel apply to all of
    # them; compute a single range that covers every subset.
    lower = min(min(y.min(), y_pred.min()) for y, y_pred in samples.values())
    upper = max(max(y.max(), y_pred.max()) for y, y_pred in samples.values())

    fig, axes = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(16, 8))

    for ax, (dataset, (y, y_pred)) in zip(axes, samples.items()):
        slope, intercept, r, _p, _se = scipy.stats.linregress(y, y_pred)

        # For denser data, ax.hexbin(y, y_pred, gridsize=30) is an alternative.
        ax.scatter(y, y_pred, s=0.2)
        ax.axline(xy1=(0, intercept), slope=slope, color="red", label=f"r = {r:.2g}")

        ax.set_title(f"{dataset} : {len(y)} points")
        ax.set_xlabel("target")
        ax.set_ylabel("prediction")

        ax.set_xlim(lower, upper)
        ax.set_ylim(lower, upper)
        ax.set_aspect("equal")
        ax.legend()

    return fig

## Experiment: Temperature (SST)

In [None]:
# Temperature experiment: predict the "TEMP" column from the band columns.
# The preprocessing callable subtracts the mean and divides by the standard
# deviation of whatever data it is given.
dm = DataModule(
    bands=BANDS,
    target="TEMP",
    train_batch_size=32,
    preprocessing=lambda df: (df - df.mean()) / df.std(),  # standardise
)
# One input per band and five hidden ReLU layers; DenseNet's defaults supply a
# single output followed by an Identity activation.
network = DenseNet(
    in_dims=len(BANDS), hidden_dims=[16, 32, 64, 32, 16], activation=torch.nn.ReLU
)
# At most 50 epochs, with early stopping monitoring "val_loss" (patience=10).
# NOTE(review): an integer val_check_interval is presumably read by Lightning
# as "validate every N training batches" - confirm for the installed version.
# NOTE(review): progress_bar_refresh_rate=0 presumably disables the progress
# bar - confirm for the installed version.
trainer = pl.Trainer(
    max_epochs=50,
    val_check_interval=10,
    progress_bar_refresh_rate=0,
    callbacks=[
        pl.callbacks.early_stopping.EarlyStopping(
            monitor="val_loss", patience=10, verbose=False
        )
    ],
)
trainer.fit(network, dm)

# Prediction-vs-target scatter plots for the train, val and test subsets.
fig = plot(dm, network)
plt.show()

## Experiment: Salinity

In [None]:
# Salinity experiment: predict the "PSAL" column from the band columns.
# The preprocessing callable subtracts the mean and divides by the standard
# deviation of whatever data it is given.
dm = DataModule(
    bands=BANDS,
    target="PSAL",
    train_batch_size=32,
    preprocessing=lambda df: (df - df.mean()) / df.std(),  # standardise
)
# One input per band and five hidden ReLU layers; DenseNet's defaults supply a
# single output followed by an Identity activation.
network = DenseNet(
    in_dims=len(BANDS), hidden_dims=[16, 32, 64, 32, 16], activation=torch.nn.ReLU
)
# At most 50 epochs, with early stopping monitoring "val_loss" (patience=10).
# NOTE(review): val_check_interval is 1 here but 10 in the temperature
# experiment; as an integer it presumably means "validate after every training
# batch" - confirm that this difference is intended.
# NOTE(review): progress_bar_refresh_rate=0 presumably disables the progress
# bar - confirm for the installed version.
trainer = pl.Trainer(
    max_epochs=50,
    val_check_interval=1,
    progress_bar_refresh_rate=0,
    callbacks=[
        pl.callbacks.early_stopping.EarlyStopping(
            monitor="val_loss", patience=10, verbose=False
        )
    ],
)
trainer.fit(network, dm)

# Prediction-vs-target scatter plots for the train, val and test subsets.
fig = plot(dm, network)
plt.show()

In [None]:
# Open TensorBoard on the lightning_logs/ directory (presumably where the
# Trainer runs above wrote their logs - verify the path).
%tensorboard --logdir lightning_logs/