In [1]:
from matplotlib import pyplot as plt
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
import mlflow
import mlflow.pytorch
import pytorch_lightning as pl
import torch.nn as nn
import torch.optim as optim


* 'schema_extra' has been renamed to 'json_schema_extra'


In [2]:
import torch


class UAV_vit(pl.LightningModule):
    """Single-output regression module built on top of a ViT-style backbone.

    The backbone's ``heads`` attribute is replaced with one linear layer that
    emits a single value per sample. During training, validation and testing
    the module accumulates per-step loss, R^2, MAE and RMSE, together with the
    residuals and predictions of every batch, keyed by phase
    (``'train'``, ``'val'``, ``'test'``).

    Args:
        backbone: Model exposing ``heads`` such that ``heads[0].in_features``
            is the input width of its head (presumably a torchvision
            VisionTransformer -- TODO confirm against the caller).
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar
            tensor.
        labels: Dictionary ``{'train': [...], 'val': [...], 'test': [...]}``.
    """

    def __init__(self, backbone, loss_fn, labels):
        super().__init__()

        self.loss_fn = loss_fn
        self.model = backbone

        # Dictionary { 'train': [], 'val': [], 'test': [] }
        self.labels = labels

        # Phase currently being executed. Initialised here so that helpers
        # reading it never raise AttributeError before the first hook fires.
        self.phase = "train"

        # Get the number of input features of the last layer of the backbone
        num_input_filters = backbone.heads[0].in_features
        num_output_values = 1

        # Replace the head of the model
        self.model.heads = nn.Linear(in_features=num_input_filters, out_features=num_output_values).float()

        # Raw (detached) model outputs, one tensor per step.
        self.training_outputs = []
        self.validation_outputs = []
        self.testing_outputs = []

        # Per-phase step histories: one scalar per step for the metrics,
        # one list of batch_size values per step for residuals/predictions.
        self.losses = {'train': [], 'val': [], 'test': []}
        self.r2_scores = {'train': [], 'val': [], 'test': []}
        self.maes = {'train': [], 'val': [], 'test': []}
        self.rmses = {'train': [], 'val': [], 'test': []}
        self.residuals = {'train': [], 'val': [], 'test': []}
        self.predicted_values = {'train': [], 'val': [], 'test': []}

        # List of dictionaries { 'loss': [], 'r2': [], 'mae': [], 'rmse': [] }
        self.epoch_metrics = []

    def forward(self, x):
        """Run the wrapped backbone (with the replaced head) on ``x``."""
        return self.model(x)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=0.001."""
        optimizer = optim.Adam(self.parameters(), lr=0.001)
        return optimizer

    # Hooks

    # Phase-agnostic debug helpers (plain methods, they only print)

    def on_phase_start(self):
        """Print the current phase and how many labels it has."""
        print(f"\n{self.phase} on_phase_start (labels {len(self.labels[self.phase])}): ")

    def on_phase_step(self):
        """Print a marker for a step of the current phase."""
        print(f"{self.phase} on_phase_step: ")

    def on_phase_end(self):
        """Print a marker for the end of the current phase."""
        print(f"{self.phase} on_phase_end: ")

    def on_fit_end(self) -> None:
        """Print a marker, tagged with the last active phase, when fit ends."""
        print(f" {self.phase} on_fit_end: ")

    # Training

    def on_train_start(self):
        """Mark the training phase as active."""
        self.phase = "train"

    def training_step(self, batch, batch_idx):
        """Run one training batch, record its metrics and return the loss."""
        # on_train_start fires only once per fit while validation runs between
        # training epochs, so the phase has to be re-asserted on every step;
        # otherwise later training steps get recorded under 'val'.
        self.phase = "train"
        outputs, labels, loss = self.get_batch_data(batch)
        # Detach so the stored tensor does not keep this batch's autograd
        # graph alive for the rest of the run.
        self.training_outputs.append(outputs.detach())
        r2, mae, rmse = self.get_metrics(outputs, labels)
        self.log_step_metrics(batch_idx, outputs, labels, {"loss": loss, "r2": r2, "mae": mae, "rmse": rmse})
        return loss

    # Validation

    def on_validation_start(self):
        """Mark the validation phase as active."""
        self.phase = "val"

    def validation_step(self, batch, batch_idx):
        """Run one validation batch and record its metrics."""
        self.phase = "val"
        outputs, labels, loss = self.get_batch_data(batch)
        self.validation_outputs.append(outputs.detach())
        r2, mae, rmse = self.get_metrics(outputs, labels)
        self.log_step_metrics(batch_idx, outputs, labels, {"loss": loss, "r2": r2, "mae": mae, "rmse": rmse})

    # Testing

    def on_test_start(self):
        """Mark the test phase as active."""
        self.phase = "test"

    def test_step(self, batch, batch_idx):
        """Run one test batch and record its metrics."""
        self.phase = "test"
        outputs, labels, loss = self.get_batch_data(batch)
        self.testing_outputs.append(outputs.detach())
        r2, mae, rmse = self.get_metrics(outputs, labels)
        self.log_step_metrics(batch_idx, outputs, labels, {"loss": loss, "r2": r2, "mae": mae, "rmse": rmse})

    # Helper functions

    def get_batch_data(self, batch):
        """Unpack a batch, run the model and compute the loss.

        Args:
            batch: Pair ``(images, labels)``.

        Returns:
            Tuple ``(outputs, labels, loss)`` where ``labels`` has been
            reshaped to a single column to line up with the model output.
        """
        images, labels = batch
        # reshape gives the same result as view but also accepts
        # non-contiguous label tensors.
        labels = labels.reshape(-1, 1)
        outputs = self.model(images)
        loss = self.loss_fn(outputs, labels)

        return outputs, labels, loss

    def clear_metrics(self):
        """Reset every per-step history of the current phase."""
        for history in (
            self.losses,
            self.rmses,
            self.r2_scores,
            self.maes,
            self.residuals,
            self.predicted_values,
        ):
            history[self.phase] = []

    def get_metrics(self, outputs, labels):
        """Compute R^2, MAE and RMSE of ``outputs`` against ``labels``.

        Args:
            outputs: Tensor of predictions.
            labels: Tensor of targets with the same shape as ``outputs``.

        Returns:
            Tuple ``(r2, mae, rmse)``.
        """
        outputs = outputs.detach().cpu().numpy()
        labels = labels.detach().cpu().numpy()

        r2 = r2_score(labels, outputs)
        mae = mean_absolute_error(labels, outputs)
        # The `squared=False` argument is deprecated/removed in recent
        # scikit-learn releases. With the single output column this model
        # produces, the square root of the MSE is the same value.
        rmse = mean_squared_error(labels, outputs) ** 0.5
        return r2, mae, rmse

    # Logging

    def log_step_metrics(self, batch_idx, outputs, labels, metrics):
        """Store and log the metrics of one step under the current phase.

        Args:
            batch_idx: Index of the batch (currently unused).
            outputs: Tensor of predictions for the batch.
            labels: Tensor of targets, same shape as ``outputs``.
            metrics: Dictionary with keys ``'loss'`` (scalar tensor),
                ``'r2'``, ``'mae'`` and ``'rmse'``.
        """
        # Flatten to plain Python floats in one call instead of calling
        # .item() element by element.
        residuals = (labels - outputs).detach().flatten().tolist()
        predicted_values = outputs.detach().flatten().tolist()

        # batch_size values for each step
        self.residuals[self.phase].append(residuals)
        self.predicted_values[self.phase].append(predicted_values)

        loss = metrics["loss"].item()

        # single value for each step
        self.losses[self.phase].append(loss)
        self.r2_scores[self.phase].append(metrics["r2"])
        self.maes[self.phase].append(metrics["mae"])
        self.rmses[self.phase].append(metrics["rmse"])

        # NOTE(review): the same metric names are logged for every phase;
        # consider prefixing them with the phase if they need to be told apart.
        self.log("loss", loss, on_epoch=True, logger=True)
        self.log("r2", metrics["r2"], on_epoch=True, logger=True)
        self.log("mae", metrics["mae"], on_epoch=True, logger=True)
        self.log("rmse", metrics["rmse"], on_epoch=True, logger=True)

    def log_epoch_results(self) -> None:
        """Average ``self.epoch_metrics`` and log the result to MLflow.

        Each metric is logged as ``"<phase>_<metric>"`` at the current epoch,
        after which ``self.epoch_metrics`` is emptied. Does nothing when no
        epoch metrics have been collected.
        """
        # Nothing in this class appends to epoch_metrics, so guard against the
        # empty case instead of dividing by zero.
        if not self.epoch_metrics:
            return

        count = len(self.epoch_metrics)
        for name in ("loss", "r2", "mae", "rmse"):
            average = sum(entry[name] for entry in self.epoch_metrics) / count
            # Each average is logged under its own name (the r2/mae/rmse
            # values used to be logged under each other's keys).
            mlflow.log_metric(f"{self.phase}_{name}", average, step=self.current_epoch)

        self.epoch_metrics = []

    # Visualization

    def create_scatterplots(self):
        """Print and scatter-plot per-step loss against per-step R^2."""
        # Both histories hold one scalar per step, so they are used as-is
        # (iterating each loss as a sub-list raises TypeError on a float).
        losses = self.losses[self.phase]
        r2_scores = self.r2_scores[self.phase]

        print("losses", losses)
        print("r2 scores", r2_scores)

        plt.scatter(losses, r2_scores, label=self.phase, alpha=0.5)
        plt.xlabel('Loss')
        plt.ylabel('R^2')
        plt.title(f'{self.phase}: Scatter Plot: Loss vs R^2')
        plt.show()

    def plot_residual_distribution(self):
        """Plot a histogram of all residuals recorded for the current phase."""
        # Residuals are stored as one list per step; merge them so the
        # histogram shows a single distribution rather than one series per
        # batch (which also breaks when the last batch is smaller).
        residuals = [value for step in self.residuals[self.phase] for value in step]

        # The label gives plt.legend() an artist to show.
        plt.hist(residuals, label=self.phase)
        plt.xlabel('Residual')
        plt.ylabel('Count')
        plt.title(f'{self.phase}: Distribution of Residuals')
        plt.legend()
        plt.show()