In [0]:
import gc
import logging
import math
import os
import pickle
import re
import shutil
from math import pi

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
from PIL import Image
from pyspark.sql import functions as F, Row
from torch.utils.data import DataLoader

# Raise the root logger's threshold to ERROR. getLogger() with no name is the
# root logger on every Python version; getLogger("root") only aliases it from
# Python 3.9 on (earlier versions create an unrelated child logger called "root").
# Named loggers that set their own level explicitly are not affected by this.
logging.getLogger().setLevel(logging.ERROR)

In [0]:
# # Number of samples
# num_samples = 25000

# # Create DataFrame with two independent uniform random columns
# df = (
#     spark.range(num_samples)
#     .withColumn("u1", F.rand(seed=42))  # Uniform for radial distance
#     .withColumn("u2", F.rand(seed=84))  # Uniform for angle
# )

# # Radial coordinate r = sqrt(u1): its pdf is proportional to r, which makes the points uniform over the disk area
# df = (
#     df.withColumn("r", F.sqrt("u1"))  # Inverse transform for density ~ r
#     .withColumn("theta", F.lit(2) * F.pi() * F.col("u2"))
#     .withColumn("abcissa", F.col("r") * F.cos("theta"))
#     .withColumn("ordinate", F.col("r") * F.sin("theta"))
#     .withColumn("y", F.col("r"))
# )

# # Show results
# df.limit(10).display()

# # Save this data to table
# df.write.saveAsTable("default.polar_coordinates_data")

# # Optional: a scatter plot should show uniform point density over the whole disk (no clustering near the center)

In [0]:
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

# Source points (polar + Cartesian columns), presumably written by the
# commented-out generation cell above.
df = spark.table("default.polar_coordinates_data")

# Held-out angular sector, in radians: theta in [pi/6, pi/3].
theta_min, theta_max = pi / 6, pi / 3

# Radial extent of the held-out region; [0, 1] spans the full radius when r is
# generated as sqrt(uniform) as in the generation cell.
radius_min, radius_max = 0.0, 1.0


# Boolean predicate: True for points inside the held-out angular/radial sector (bounds inclusive)
def sector_filter(theta_col, r_col, theta_min, theta_max, radius_min, radius_max):
    """Element-wise mask that is True where a point lies inside the sector.

    Both the angular bounds ``[theta_min, theta_max]`` and the radial bounds
    ``[radius_min, radius_max]`` are inclusive. The comparisons are combined
    with ``&``, so this builds a filter expression when given Spark ``Column``
    objects and also evaluates directly on scalars or numpy arrays.
    """
    within_angle = (theta_col >= theta_min) & (theta_col <= theta_max)
    within_radius = (r_col >= radius_min) & (r_col <= radius_max)
    return within_angle & within_radius


# NOTE(review): train_data / val_data / test_data (lists of Rows collected to the
# driver) do not appear to be used by the later cells, which re-split the data
# with DataFrame.randomSplit instead -- consider dropping this collection.
sector_mask = sector_filter(
    F.col("theta"), F.col("r"), theta_min, theta_max, radius_min, radius_max
)

# Everything outside the held-out sector, split 80/20 into train / validation.
training_data_all = df.where(~sector_mask).select("abcissa", "ordinate", "y").collect()
train_data, val_data = train_test_split(
    training_data_all, test_size=0.2, random_state=42
)

# The sector itself is the test set.
test_data = df.where(sector_mask).select("abcissa", "ordinate", "y").collect()

In [0]:
class SparkDataFrameDataset(torch.utils.data.Dataset):
    """In-memory torch Dataset built from two feature columns and a target column of a Spark DataFrame.

    The three selected columns are pulled to the driver once with ``toPandas()``.
    Features are either the Cartesian pair ``(x1, x2)`` or, with
    ``polar_coordinates=True``, ``(r, theta)`` where ``r = sqrt(x1**2 + x2**2)``
    and ``theta = atan2(x2, x1)`` (radians, in ``(-pi, pi]``).

    Attributes:
        data: float32 tensor of shape (N, 2) holding the features; the shape is
            (0, 2) for an empty input as well.
        targets: float32 tensor of shape (N, 1) holding ``target_col``.
    """

    def __init__(
        self,
        spark_df,
        x1="abcissa",
        x2="ordinate",
        target_col="y",
        polar_coordinates=False,
    ):
        frame = spark_df.select(x1, x2, target_col).toPandas()
        # Work in float64 (also accepts integer columns) and cast to float32 once
        # at the end; vectorised instead of building a Python list of tuples.
        xs = torch.tensor(frame[x1].to_numpy(), dtype=torch.float64)
        ys = torch.tensor(frame[x2].to_numpy(), dtype=torch.float64)
        if polar_coordinates:
            features = torch.stack((torch.sqrt(xs**2 + ys**2), torch.atan2(ys, xs)), dim=1)
        else:
            # Stacking keeps shape (N, 2) even when N == 0
            # (torch.tensor(list(zip(...))) produced shape (0,) there).
            features = torch.stack((xs, ys), dim=1)
        self.data = features.float()
        self.targets = torch.tensor(
            frame[target_col].to_numpy(), dtype=torch.float32
        ).unsqueeze(1)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx]


# Boolean Spark column: True for points inside the held-out sector.
sector_mask = sector_filter(
    F.col("theta"), F.col("r"), theta_min, theta_max, radius_min, radius_max
)

# Points outside the sector: seeded 80/20 train/validation split.
df_train, df_val = df.where(~sector_mask).randomSplit([0.8, 0.2], seed=42)

# Points inside the sector: the out-of-distribution test set.
df_test = df.where(sector_mask)

print(df_train.count(), df_val.count(), df_test.count())

In [0]:
# Pull each split to the driver with a single toPandas() call, so the x, y and
# target lists below come from the same rows in the same order. One Spark
# action per column gives no guarantee that rows come back in matching order,
# and re-runs the query (including the randomSplit) every time.
train_points = df_train.select("abcissa", "ordinate", "y").toPandas()
val_points = df_val.select("abcissa", "ordinate", "y").toPandas()
test_points = df_test.select("abcissa", "ordinate").toPandas()

# Extract x, y, and output values for training data
x_values_train = train_points["abcissa"].tolist()
y_values_train = train_points["ordinate"].tolist()
output_values_train = train_points["y"].tolist()

x_values_val = val_points["abcissa"].tolist()
y_values_val = val_points["ordinate"].tolist()
output_values_val = val_points["y"].tolist()

# Extract x, y values for test data
x_values_test = test_points["abcissa"].tolist()
y_values_test = test_points["ordinate"].tolist()

# Plot the training points coloured by their target and the held-out test
# sector in solid red.
fig, ax = plt.subplots(figsize=(8, 8), dpi=256)
train_scatter = ax.scatter(
    x_values_train,
    y_values_train,
    c=output_values_train,
    cmap="twilight",
    alpha=1,
    s=0.5,
    label="Training Data",
)
ax.scatter(
    x_values_test, y_values_test, color="red", alpha=1, s=0.5, label="Test Data"
)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_title("2D Scatter Plot of x and y (colored by output)")
# Attach the colorbar to the training scatter explicitly: plt.colorbar() with
# no argument uses the most recently drawn scatter -- the plain red test
# points -- whose default colormap is not the "twilight" scale of the outputs.
fig.colorbar(train_scatter, ax=ax, label="output")
ax.grid(True)
ax.set_aspect("equal", adjustable="box")
ax.legend()
plt.show()

In [0]:
class _MSERegressor(pl.LightningModule):
    """Shared Lightning plumbing for the regressors below.

    Subclasses assign ``self.model`` (a module mapping the 2 input features to
    1 output) after calling ``super().__init__``. This base class supplies the
    forward pass, MSE training/validation steps (logged as ``train_loss`` /
    ``val_loss``) and the Adam optimizer, so the concrete regressors differ only
    in their network definition.

    Args:
        lr: Adam learning rate (default 1e-3, the value previously hard-coded).
    """

    def __init__(self, lr: float = 1e-3):
        super().__init__()
        self.lr = lr

    def forward(self, x):
        return self.model(x)

    def _mse_step(self, batch, stage: str):
        """Compute the batch MSE, log it as ``f"{stage}_loss"`` (per step and epoch), and return it."""
        x, y = batch
        loss = nn.functional.mse_loss(self(x), y)
        self.log(f"{stage}_loss", loss, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._mse_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        self._mse_step(batch, "val")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


class RegressorLinearSingleLayer(_MSERegressor):
    """Linear baseline: a single ``nn.Linear(2, 1)`` layer."""

    def __init__(self, lr: float = 1e-3):
        super().__init__(lr=lr)
        self.model = nn.Linear(2, 1)


class RegressorLinearMultiLayer(_MSERegressor):
    """MLP 2 -> 64 -> 32 -> 1 with ReLU activations between the linear layers."""

    def __init__(self, lr: float = 1e-3):
        super().__init__(lr=lr)
        self.model = nn.Sequential(
            nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 1)
        )


def train_model(
    model_parameters,
    dataset_train,
    dataset_val,
    dataset_test,
    callbacks=None,
    batch_size=256,
):
    """Fit ``model_parameters["model"]`` with a single-device Lightning Trainer.

    Args:
        model_parameters: dict with at least ``"model"`` (a LightningModule,
            trained in place) and ``"max_epochs"`` (int).
        dataset_train: training dataset; batched with shuffling.
        dataset_val: validation dataset; batched in order.
        dataset_test: accepted for call-site compatibility; not used here.
        callbacks: Trainer callbacks. When omitted, falls back to the
            notebook-level ``checkpoint_callback`` (created in the training-loop
            cell) so existing calls keep writing per-epoch checkpoints; if that
            global does not exist, no extra callbacks are passed.
        batch_size: batch size for both loaders (256 matches the loop cell).

    Returns:
        The fitted ``pl.Trainer``.
    """
    if callbacks is None:
        fallback = globals().get("checkpoint_callback")
        callbacks = [fallback] if fallback is not None else []

    # Build the loaders from the datasets actually passed in, instead of
    # silently reading dataloader_train / dataloader_val from notebook globals.
    dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
    dataloader_val = DataLoader(dataset_val, batch_size=batch_size)

    trainer = pl.Trainer(
        max_epochs=model_parameters["max_epochs"],
        accelerator="auto",
        devices=1,
        callbacks=callbacks,
    )
    trainer.fit(
        model=model_parameters["model"],
        train_dataloaders=dataloader_train,
        val_dataloaders=dataloader_val,
    )
    return trainer


def _checkpoint_epoch(filename):
    """Return the integer N from the last ``epoch=N`` tag in a checkpoint filename."""
    # Files are named like "<model_name>_epoch=003.ckpt"; a Lightning version
    # suffix ("..._epoch=003-v1.ckpt") would break a plain split on "=" / ".".
    matches = re.findall(r"epoch=(\d+)", filename)
    if not matches:
        raise ValueError(f"Cannot find an 'epoch=<n>' tag in checkpoint name {filename!r}")
    return int(matches[-1])


def get_epoch_predictions(model, model_name, dataloader_test):
    """Score the test loader with every saved checkpoint of ``model_name``.

    Each ``*.ckpt`` file in ``/Workspace/Shared/lightning_logs/{model_name}/checkpoints``
    is loaded into ``model`` (in ascending epoch order) and the model is run in
    eval mode, without gradients, over ``dataloader_test``.

    Args:
        model: module whose parameters match the checkpoints; its weights are
            overwritten in place by each checkpoint.
        model_name: run name used to locate the checkpoint directory.
        dataloader_test: yields ``(x, target)`` batches; targets are ignored.
            It should not shuffle, so predictions stay in test-row order.

    Returns:
        Spark DataFrame with columns ``epoch`` (long) and ``prediction``
        (double): one row per (checkpoint, test sample), grouped by epoch and in
        loader order within each epoch. The first 10 rows are displayed.

    Raises:
        FileNotFoundError: if the directory contains no ``.ckpt`` file.
        ValueError: if a checkpoint filename carries no ``epoch=<n>`` tag.
    """
    checkpoint_dir = f"/Workspace/Shared/lightning_logs/{model_name}/checkpoints"
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".ckpt")]
    if not checkpoints:
        raise FileNotFoundError(f"No .ckpt files found in {checkpoint_dir}")

    # Feed inputs to wherever the model's parameters currently live.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    epoch_frames = []
    for checkpoint in sorted(checkpoints, key=_checkpoint_epoch):
        epoch_num = _checkpoint_epoch(checkpoint)
        # map_location="cpu" lets checkpoints written on a GPU load on a CPU-only
        # driver; load_state_dict then copies the tensors onto the model's device.
        checkpoint_data = torch.load(
            os.path.join(checkpoint_dir, checkpoint), map_location="cpu"
        )
        # strict=False (kept from the original) tolerates key mismatches, but it
        # would also silently leave unmatched parameters at their current values.
        model.load_state_dict(checkpoint_data["state_dict"], strict=False)
        model.eval()
        with torch.no_grad():
            batch_preds = [model(x.to(device)).cpu() for x, _ in dataloader_test]
        preds_np = torch.cat(batch_preds).numpy().flatten()
        epoch_frames.append(
            pd.DataFrame(
                {
                    "epoch": np.full(len(preds_np), epoch_num, dtype=np.int64),
                    "prediction": preds_np.astype(np.float64),
                }
            )
        )

    # One Spark DataFrame instead of one per epoch plus chained unions; pandas
    # int64 / float64 map to the same long / double columns the Row version had.
    df_all_epochs = spark.createDataFrame(pd.concat(epoch_frames, ignore_index=True))

    display(df_all_epochs.limit(10))
    return df_all_epochs


def plot_and_save_predictions_by_epoch(
    epochs,
    preds_by_epoch,
    x_values_train,
    y_values_train,
    output_values_train,
    x_values_test,
    y_values_test,
    path_plot_dir,
    model_name,
):
    """Render and persist one train-vs-prediction scatter per epoch.

    For each epoch, the training points are coloured by their target and the
    test points by that epoch's predictions (``preds_by_epoch[epoch]``, which
    must align element-wise with ``x_values_test`` / ``y_values_test``). Three
    files are written to ``path_plot_dir`` per epoch:
    ``{model_name}_Plot_epoch=NNN.pkl`` (pickled figure),
    ``{model_name}_Plot_epoch=NNN.png``, and
    ``{model_name}_Data_epoch=NNN.csv`` (test x, y and prediction).
    """
    # One colour scale for both scatters, fixed to the training-target range, so
    # a prediction is drawn in the same colour as an equal ground-truth value.
    # Left to autoscale, each scatter would stretch its own min/max over the
    # whole colormap and the colours would not be comparable. Predictions outside
    # the range are drawn with the colormap's end colours.
    # NOTE(review): "twilight" is a cyclic colormap, so its two ends share a colour.
    if len(output_values_train) > 0:
        vmin = float(np.min(output_values_train))
        vmax = float(np.max(output_values_train))
    else:
        vmin = vmax = None

    for epoch in epochs:
        test_predictions = preds_by_epoch[epoch]

        fig, ax = plt.subplots(figsize=(8, 8), dpi=256)
        ax.scatter(
            x_values_train,
            y_values_train,
            c=output_values_train,
            cmap="twilight",
            vmin=vmin,
            vmax=vmax,
            alpha=1,
            s=0.5,
            label="Train Data",
        )
        prediction_scatter = ax.scatter(
            x_values_test,
            y_values_test,
            c=test_predictions,
            cmap="twilight",
            vmin=vmin,
            vmax=vmax,
            alpha=1,
            s=0.5,
            label="Test Predictions",
        )
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_title(f"2D Scatter Plot of x and y (colored by output) - Epoch {epoch}")
        ax.legend()
        ax.grid(True)
        ax.set_aspect("equal", adjustable="box")
        fig.colorbar(prediction_scatter, ax=ax, label="output")

        with open(
            f"{path_plot_dir}/{model_name}_Plot_epoch={epoch:03d}.pkl", "wb"
        ) as f:
            pickle.dump(fig, f)

        fig.savefig(
            f"{path_plot_dir}/{model_name}_Plot_epoch={epoch:03d}.png",
            bbox_inches="tight",
            pad_inches=0,
        )
        pd.DataFrame(
            {"x": x_values_test, "y": y_values_test, "prediction": test_predictions}
        ).to_csv(
            f"{path_plot_dir}/{model_name}_Data_epoch={epoch:03d}.csv", index=False
        )
        # Close this figure explicitly: one large (8x8 in @ 256 dpi) figure is
        # created per epoch and would otherwise accumulate in pyplot's registry.
        plt.close(fig)
        gc.collect()


def display_prediction_images(path_plot_dir):
    """Tile the PNGs in ``path_plot_dir`` into one square grid, save it, and show it.

    Images are placed in sorted filename order (for the zero-padded
    ``..._epoch=NNN.png`` files this notebook writes, that is epoch order).
    The composite is written to ``grid_image.png`` in the same directory. That
    file is skipped as an input, so re-running does not paste an old grid into
    the new one. Does nothing if there are no input PNGs.
    """
    grid_filename = "grid_image.png"
    png_files = sorted(
        name
        for name in os.listdir(path_plot_dir)
        if name.endswith(".png") and name != grid_filename
    )
    if not png_files:
        return

    images = []
    for name in png_files:
        # Copy the pixels out and close each file right away instead of keeping
        # one open file handle per image.
        with Image.open(os.path.join(path_plot_dir, name)) as img:
            images.append(np.asarray(img))

    # Smallest square that fits every image (int(sqrt(n)) + 1 added a spare
    # row and column whenever n was a perfect square).
    grid_size = math.isqrt(len(images))
    if grid_size * grid_size < len(images):
        grid_size += 1

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(20, 20), squeeze=False)
    flat_axes = axes.ravel()
    for ax, img in zip(flat_axes, images):
        ax.imshow(img)
    for ax in flat_axes:
        ax.axis("off")

    fig.subplots_adjust(wspace=0, hspace=0)
    output_path = os.path.join(path_plot_dir, grid_filename)
    fig.savefig(output_path, bbox_inches="tight", pad_inches=0)
    plt.show()
    plt.close(fig)

In [0]:
# Train every (architecture, coordinate system) combination, score the held-out
# sector with each epoch's checkpoint, and plot the predictions.
LIGHTNING_LOG_ROOT = "/Workspace/Shared/lightning_logs"
SEED = 42

for model_cls in (RegressorLinearSingleLayer, RegressorLinearMultiLayer):
    for polar_coordinates in [True, False]:
        # Seed before building the model so weight initialisation and
        # DataLoader shuffling are reproducible for every configuration.
        pl.seed_everything(SEED)

        model_parameters = {
            "model": model_cls(),
            "iteration": 1,
            "polar_coordinates": polar_coordinates,
            "max_epochs": 20,
        }

        model_name = (
            f"{model_parameters['model'].__class__.__name__}"
            f"_PolarCoordinates_{model_parameters['polar_coordinates']}"
            f"_{model_parameters['iteration']:03d}"
        )

        # Per-run output directories. Both are wiped first, because
        # get_epoch_predictions reads every *.ckpt in the checkpoint dir and
        # display_prediction_images tiles every *.png in the plot dir, so
        # leftovers from earlier runs (e.g. an old grid image) would leak in.
        checkpoint_dir = f"{LIGHTNING_LOG_ROOT}/{model_name}/checkpoints/"
        path_plot_dir = f"{LIGHTNING_LOG_ROOT}/{model_name}/plots/"
        for stale_dir in (checkpoint_dir, path_plot_dir):
            if os.path.exists(stale_dir):
                shutil.rmtree(stale_dir)
        os.makedirs(path_plot_dir, exist_ok=True)

        # Keep a weights-only checkpoint after every epoch (save_top_k=-1).
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            dirpath=checkpoint_dir,
            filename=f"{model_name}_{{epoch:03d}}",
            save_top_k=-1,
            save_weights_only=True,
            every_n_epochs=1,
            monitor="val_loss",
            mode="min",
        )

        dataset_train = SparkDataFrameDataset(
            df_train, polar_coordinates=polar_coordinates
        )
        dataset_val = SparkDataFrameDataset(
            df_val, polar_coordinates=polar_coordinates
        )
        dataset_test = SparkDataFrameDataset(
            df_test, polar_coordinates=polar_coordinates
        )

        # train_model may read dataloader_train, dataloader_val and
        # checkpoint_callback from the notebook globals, so keep these names.
        dataloader_train = DataLoader(dataset_train, batch_size=256, shuffle=True)
        dataloader_val = DataLoader(dataset_val, batch_size=256)
        # No shuffling: predictions must stay in the same order as the test rows.
        dataloader_test = DataLoader(dataset_test, batch_size=256)

        train_model(model_parameters, dataset_train, dataset_val, dataset_test)

        df_all_epochs = get_epoch_predictions(
            model_parameters["model"], model_name, dataloader_test
        )

        # Deliberately no orderBy("epoch"): a global sort shuffles rows and does
        # not keep the test-row order within an epoch, which would misalign the
        # predictions with x_values_test / y_values_test. This relies on
        # get_epoch_predictions returning rows in loader order per epoch (it
        # builds the frame from local data, with no shuffle). pandas groupby
        # keeps the within-group row order.
        predictions_pdf = df_all_epochs.toPandas()
        preds_by_epoch = {
            int(epoch): group["prediction"].to_numpy()
            for epoch, group in predictions_pdf.groupby("epoch", sort=True)
        }
        epochs = list(preds_by_epoch)

        plot_and_save_predictions_by_epoch(
            epochs,
            preds_by_epoch,
            x_values_train,
            y_values_train,
            output_values_train,
            x_values_test,
            y_values_test,
            path_plot_dir,
            model_name,
        )

        display_prediction_images(path_plot_dir)
        print("=" * 80)