In [1]:
import warnings
from tqdm import tqdm

warnings.filterwarnings("ignore")

In [2]:
import pandas as pd
import numpy as np

df = pd.read_csv("./out.csv")[
    [
        "V0_LAT",
        "V1_LAT",
        "V2_LAT",
        "V3_LAT",
        "V0_LON",
        "V1_LON",
        "V2_LON",
        "V3_LON",
        "Fe_wt",
        "Ti_wt",
        "Ca_wt",
        "Si_wt",
        "Al_wt",
        "Mg_wt",
        "Na_wt",
        "O_wt",
    ]
]

wt_cols = [col for col in df.columns if "_wt" == col[-3:]]
df[wt_cols] /= 100

df["lat"] = (df["V0_LAT"] + df["V1_LAT"] + df["V2_LAT"] + df["V3_LAT"]) / 4
df["lon"] = (df["V0_LON"] + df["V1_LON"] + df["V2_LON"] + df["V3_LON"]) / 4
df["dlat"] = np.max(
    (
        np.abs(df["V0_LAT"] - df["V1_LAT"]),
        np.abs(df["V0_LAT"] - df["V2_LAT"]),
        np.abs(df["V3_LAT"] - df["V1_LAT"]),
        np.abs(df["V3_LAT"] - df["V2_LAT"]),
    )
)
df["dlon"] = np.max(
    (
        np.abs(df["V0_LON"] - df["V2_LON"]),
        np.abs(df["V0_LON"] - df["V3_LON"]),
        np.abs(df["V1_LON"] - df["V2_LON"]),
        np.abs(df["V1_LON"] - df["V3_LON"]),
    )
)

In [3]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torch import optim

In [4]:
class PositionalDataset(Dataset):
    def __init__(
        self, df, nstd=1 / 2.6, renorm=True, device="cpu"
    ):  # 2.6: 99%, 3.3: 99.9%, 3.9: 99.99%, 4.5: 99.999%
        self.df = df
        self.nstd = nstd
        self.renorm = renorm
        self.device = device
        self.wt_cols = [
            "Mg_wt",
            "Al_wt",
            "Si_wt",
            "Ca_wt",
            "Ti_wt",
            "Fe_wt",
            # "Na_wt",
            "O_wt",
        ]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        lat = self.df["lat"].iloc[idx]
        lon = self.df["lon"].iloc[idx]
        dlat = self.df["dlat"].iloc[idx]
        dlon = self.df["dlon"].iloc[idx]

        if self.renorm:
            self.df[self.wt_cols].iloc[idx] /= self.df[self.wt_cols].iloc[idx].sum()

        return torch.normal(
            torch.tensor((lat, lon), dtype=torch.float32),
            self.nstd * torch.tensor((dlat, dlon), dtype=torch.float32),
        ).to(self.device), torch.tensor(
            self.df[self.wt_cols].iloc[idx].values,
            dtype=torch.float32,
        ).to(
            self.device
        )

In [5]:
class PositionalEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 1024)
        self.fc2 = nn.Linear(1024, 1024)
        self.fc3 = nn.Linear(1024, 512)
        self.fc4 = nn.Linear(512, 7)
        self.ac1 = nn.LeakyReLU()
        self.norm1 = nn.BatchNorm1d(1024)
        self.norm2 = nn.BatchNorm1d(512)

    def forward(self, x):
        x = self.ac1(self.fc1(x))
        x = self.norm1(self.ac1(self.fc2(x)))
        x = self.norm2(self.ac1(self.fc3(x)))
        x = self.ac1(self.fc4(x))
        return x, F.softmax(x, dim=1)[:, :6]


k = 10
models = [PositionalEncoder() for _ in range(k)]
for i, model in enumerate(models):
    model.load_state_dict(
        torch.load(
            f"./ckpts/20241206170035/model_{i}.pt",
            weights_only=False,
            map_location=torch.device("cpu"),
        )
    )
    model.to("mps")
    model.eval()

In [6]:
valid_dl = DataLoader(
    PositionalDataset(df, renorm=False, device="mps"),
    batch_size=1024,
    shuffle=True,
    num_workers=0,
)

In [None]:
atwts = pd.read_csv("./data_constants/atomicweight.txt", sep="\t", header=None)
atwts.columns = ["atno", "sym", "atwt"]
atwts.set_index("sym", inplace=True)
atwts.atwt

elements = ["mg", "al", "si", "ca", "ti", "fe", "o"]

ele_wts = atwts.loc[elements].atwt

oxides = ["mgo", "al2o3", "sio2", "cao", "tio2", "feo"]

owts = {
    "mgo": ele_wts["o"] / (ele_wts["mg"] + ele_wts["o"]),
    "al2o3": ele_wts["o"] * 3 / (2 * ele_wts["al"] + 3 * ele_wts["o"]),
    "sio2": ele_wts["o"] * 2 / (ele_wts["si"] + 2 * ele_wts["o"]),
    "cao": ele_wts["o"] / (ele_wts["ca"] + ele_wts["o"]),
    "tio2": ele_wts["o"] * 2 / (ele_wts["ti"] + 2 * ele_wts["o"]),
    "feo": ele_wts["o"] / (ele_wts["fe"] + ele_wts["o"]),
}

orelwt = list(owts.values())
orelwt = torch.tensor(orelwt, dtype=torch.float32).to("mps")


def position_to_elements(input):
    outputs = torch.stack([model(input)[1] for model in models])
    oxides = torch.mean(outputs, axis=0)
    ox_contrib = oxides * orelwt
    elems = oxides - ox_contrib
    oxygen = torch.sum(ox_contrib, axis=1)

    return torch.cat((elems, oxygen.unsqueeze(1)), dim=1)


def position_to_elements_model(model, input):
    oxides = model(input)[1]
    ox_contrib = oxides * orelwt
    elems = oxides - ox_contrib
    oxygen = torch.sum(ox_contrib, axis=1)

    return torch.cat((elems, oxygen.unsqueeze(1)), dim=1)

In [8]:
# k means cross validation
k = 10
shuffled_df = df.sample(frac=1, random_state=42).reset_index(drop=True)
dl = len(shuffled_df) // k

test_dfs = [shuffled_df[i * dl : (i + 1) * dl] for i in range(k - 1)]
test_dfs.append(shuffled_df[(k - 1) * dl :])


train_dfs = [shuffled_df.drop(test_df.index) for test_df in test_dfs]
train_dls = [
    DataLoader(PositionalDataset(train_df), batch_size=1024, shuffle=True)
    for train_df in train_dfs
]
test_dls = [
    DataLoader(PositionalDataset(test_df), batch_size=1024, shuffle=True)
    for test_df in test_dfs
]

optimizers = [optim.Adam(model.parameters(), lr=1e-3) for model in models]

In [None]:
from torch.utils.tensorboard import SummaryWriter
import numpy as np

writer = SummaryWriter(log_dir="./log")

epochs = 1000
mse = [[[] for _ in range(k)] for _ in range(epochs)]
msev = [[[] for _ in range(k)] for _ in range(epochs)]
cnts = [[0 for _ in range(k)] for _ in range(epochs)]
cntsv = [[0 for _ in range(k)] for _ in range(epochs)]

for epoch in range(epochs):
    print("epoch", epoch)

    train_mse_list = []
    test_mse_list = []

    for i in range(k):
        model = models[i]
        optimizer = optimizers[i]

        # Training phase
        model.train()
        for input, target in train_dls[i]:
            input = input.to("mps")
            target = target.to("mps")
            output = position_to_elements_model(model, input)
            mse_loss = F.mse_loss(output, target)
            # kld_loss = F.kl_div(output, target, reduction="batchmean")
            loss = mse_loss  # + 0.1 * kld_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            cnts[epoch][i] += input.shape[0]
            mse[epoch][i].append(mse_loss.item() * input.shape[0])

        # Record training MSE for this model
        train_mse = sum(mse[epoch][i]) / cnts[epoch][i]
        train_mse_list.append(train_mse)
        writer.add_scalar(f"Model_{i}/Train_MSE", train_mse, epoch)

        # Evaluation phase
        model.eval()
        with torch.no_grad():
            for input, target in test_dls[i]:
                input = input.to("mps")
                target = target.to("mps")
                output = position_to_elements_model(model, input)
                mse_loss = F.mse_loss(output, target)
                msev[epoch][i].append(mse_loss.item() * input.shape[0])
                cntsv[epoch][i] += input.shape[0]

        # Record testing MSE for this model
        test_mse = sum(msev[epoch][i]) / cntsv[epoch][i]
        test_mse_list.append(test_mse)
        writer.add_scalar(f"Model_{i}/Test_MSE", test_mse, epoch)

    # Average MSE over all models
    avg_train_mse = sum(train_mse_list) / k
    avg_test_mse = sum(test_mse_list) / k

    # Statistical Error
    statistical_error = np.sqrt(avg_test_mse) / k

    # Log average metrics to TensorBoard
    writer.add_scalar("Average/Train_MSE", avg_train_mse, epoch)
    writer.add_scalar("Average/Test_MSE", avg_test_mse, epoch)
    writer.add_scalar("Average/Statistical_Error", statistical_error, epoch)

    # Print metrics to console
    print("Average Training MSE:", avg_train_mse)
    print("Average Testing MSE:", avg_test_mse)
    print("Statistical Error +/-", statistical_error)

# Close TensorBoard writer
writer.close()

epoch 0
Average Training MSE: 0.0029560335500894187
Average Testing MSE: 0.0030233065597712995
Statistical Error +/- 0.0054984602933651336
epoch 1
Average Training MSE: 0.0029554269506172326
Average Testing MSE: 0.0029917686711996793
Statistical Error +/- 0.005469706272917842
epoch 2
Average Training MSE: 0.0029519658214172797
Average Testing MSE: 0.002989768725819886
Statistical Error +/- 0.005467877765477101
epoch 3
Average Training MSE: 0.002946641375438775
Average Testing MSE: 0.0029980957740917804
Statistical Error +/- 0.005475486986644914
epoch 4
Average Training MSE: 0.002945443377062151
Average Testing MSE: 0.0029638934414833785
Statistical Error +/- 0.005444165171523894
epoch 5
Average Training MSE: 0.0029435680130173907
Average Testing MSE: 0.0029720784863457085
Statistical Error +/- 0.005451677252319426
epoch 6
Average Training MSE: 0.002948948370097722
Average Testing MSE: 0.002973707253113389
Statistical Error +/- 0.005453170869423944
epoch 7
Average Training MSE: 0.002945