# Train a neural network for regression

- Watch this video https://lightning.ai/docs/pytorch/stable/starter/introduction.html

In [None]:
# Replace this test data with your data and adapt the code accordingly

import pandas as pd

# Example dataset (JSON) fetched from the kuennethgroup materials_datasets repository.
# Later cells read its "fingerprint" and "value" columns.
DATA_URL = "https://raw.githubusercontent.com/kuennethgroup/materials_datasets/refs/heads/main/polymer_tendency_to_crystalize/polymers_tend_to_crystalize.json"

df = pd.read_json(DATA_URL)
df

In [None]:
import torch, torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as F
import lightning as L

# --------------------------------
# Define a LightningModule by subclassing LightningModule
# A LightningModule is a subclass of nn.Module


class LitRegressor(L.LightningModule):
    """Fully connected network that regresses one scalar from a feature vector.

    Architecture: ``input_size -> hidden1 -> hidden2 -> 1`` with ReLU and
    dropout after each hidden layer. Trained with MSE loss and Adam.

    Args:
        input_size: Length of each input feature vector (must match the data).
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        dropout: Dropout probability applied after each hidden layer.
        lr: Learning rate for the Adam optimizer.
    """

    def __init__(self, input_size=2048, hidden1=300, hidden2=100, dropout=0.4, lr=1e-3):
        super().__init__()
        # Record the constructor arguments in checkpoints so that
        # `load_from_checkpoint` rebuilds a network with the same shape.
        self.save_hyperparameters()
        self.l1 = nn.Sequential(nn.Linear(input_size, hidden1), nn.ReLU(), nn.Dropout(dropout))
        self.l2 = nn.Sequential(nn.Linear(hidden1, hidden2), nn.ReLU(), nn.Dropout(dropout))
        self.l3 = nn.Sequential(nn.Linear(hidden2, 1))

    def forward(self, x):
        """Return predictions of shape ``(batch, 1)`` for inputs ``x``."""
        return self.l3(self.l2(self.l1(x)))

    def _shared_step(self, batch):
        """Compute the MSE loss for one ``(inputs, targets)`` batch."""
        x, y = batch
        y_hat = self(x)
        # Reshape targets to (batch, -1) so they line up with the (batch, 1) predictions.
        y = y.view(y_hat.size(0), -1)
        # F.mse_loss takes (input, target); keep that order in both train and val.
        return F.mse_loss(y_hat, y)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss = self._shared_step(batch)
        # Logged so the loss curve can be plotted later.
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        loss = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters, using the stored learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

## Split the data and build data loaders

In [None]:
# Display the DataFrame loaded above to inspect its columns before building tensors.
df

In [None]:
import torch.utils.data as data_utils
import numpy as np
from sklearn.model_selection import train_test_split

BATCH_SIZE = 30
SEED = 123


def _frame_to_tensors(frame):
    """Return ``(fingerprints, values)`` as float32 tensors for one DataFrame split.

    The cells of the ``fingerprint`` and ``value`` columns are stacked with
    ``np.stack``, so each fingerprint becomes one row of the feature tensor.
    """
    fps = torch.from_numpy(np.stack(frame["fingerprint"].values).astype(np.float32))
    values = torch.from_numpy(np.stack(frame["value"].values).astype(np.float32))
    return fps, values


# split
train, val = train_test_split(df, shuffle=True, random_state=SEED)

# Train
train_fps, train_values = _frame_to_tensors(train)
train_data = data_utils.TensorDataset(train_fps, train_values)
# Reshuffle every epoch; the single shuffle in train_test_split would otherwise
# produce identical batches each epoch. Seeded so runs are reproducible.
train_loader = data_utils.DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

# Validation
val_fps, val_values = _frame_to_tensors(val)
val_data = data_utils.TensorDataset(val_fps, val_values)
val_loader = data_utils.DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False)

# Test
# NOTE: there is no separate hold-out split yet, so the validation split is reused.
# The validation split also drives early stopping and checkpoint selection, so
# metrics on this "test" set are optimistic.
test_fps, test_values = _frame_to_tensors(val)
# The loader yields inputs only; the targets stay in `test_values` for scoring.
test_loader = data_utils.DataLoader(test_fps, batch_size=BATCH_SIZE, shuffle=False)

# Train

In [None]:
from lightning.pytorch.callbacks import ModelSummary
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger


from pathlib import Path
import shutil


# Remove checkpoints from earlier runs so this directory only holds this run's best model.
save_dir = Path("my_pytorch_model")
if save_dir.exists():
    shutil.rmtree(save_dir)

callbacks = [
    # Keep only the single checkpoint with the best val_loss.
    ModelCheckpoint(dirpath=save_dir, save_top_k=1, monitor="val_loss", verbose=True),
    ModelSummary(max_depth=-1),
    EarlyStopping(monitor="val_loss", mode="min", verbose=True),
]

# Seed weight initialisation and dropout so training runs are reproducible.
L.seed_everything(123)

regressor = LitRegressor()
trainer = L.Trainer(
    max_epochs=50,
    log_every_n_steps=1,
    # Validate once per training epoch. A float is a fraction of an epoch; an
    # int N means "every N training batches", so `1` would validate after
    # every single batch.
    val_check_interval=1.0,
    callbacks=callbacks,
    logger=[CSVLogger(".")],
)
trainer.fit(model=regressor, train_dataloaders=train_loader, val_dataloaders=val_loader)

## Plot

In [None]:
# Load the per-step metrics written by the CSV logger during training.
metrics_csv = Path(trainer.logger.log_dir) / "metrics.csv"
metrics = pd.read_csv(metrics_csv).set_index(["epoch", "step"])

# A row may set only one of the two losses; drop each column's NaNs and
# realign both losses on their (epoch, step) index before plotting.
loss_columns = ("train_loss", "val_loss")
df_res = pd.concat([metrics[name].dropna() for name in loss_columns], axis=1)
df_res.plot()

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import root_mean_squared_error, r2_score

fig, ax = plt.subplots()
# The training cell empties save_dir and keeps save_top_k=1, so one checkpoint is expected here.
best_regressor = LitRegressor.load_from_checkpoint(list(save_dir.glob("*.ckpt"))[0])

# Use a separate inference-only trainer. This leaves `trainer` (whose log_dir
# the plot cell reads) untouched and avoids creating a new logging directory.
predict_trainer = L.Trainer(logger=False)
preds = predict_trainer.predict(best_regressor, test_loader)
# reshape(-1) rather than squeeze() keeps a 1-D array even for a single sample;
# .cpu() makes .numpy() safe if prediction ran on an accelerator.
preds = torch.cat(preds).reshape(-1).cpu().numpy()
y_true = test_values.numpy()

# sklearn metrics take (y_true, y_pred). R² is not symmetric, so the order matters.
rmse = root_mean_squared_error(y_true, preds)
r2 = r2_score(y_true, preds)

ax.plot(preds, y_true, "o")
# Identity (perfect-prediction) line spanning the observed data range.
lo = float(min(preds.min(), y_true.min()))
hi = float(max(preds.max(), y_true.max()))
ax.plot([lo, hi], [lo, hi], "k--")
ax.set_ylabel("true")
ax.set_xlabel("pred")
print(f"{rmse = } [%]")
print(f"{r2 = }")