In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import os
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import random

## Model loading

In [2]:
class MaskedAutoEncoder(nn.Module):
    """MLP masked auto-encoder over fixed-width feature rows.

    Entries where ``mask`` is False are replaced by the learned ``mask_token``
    before encoding; the decoder maps the ``embed_dim`` latent back to
    ``in_dim`` features.

    The attribute names (``mask_token``, ``encoder``, ``decoder``) and the
    layer order inside each ``nn.Sequential`` define the ``state_dict`` keys,
    so they must not change or saved checkpoints will no longer load.

    Args:
        in_dim: number of input features per row.
        maskable_dim: stored on the instance; not used by this class itself.
        embed_dim: width of every encoder stage and of the latent.
        depth: number of Linear -> GELU -> LayerNorm stages in the encoder.
    """

    def __init__(self, in_dim=9, maskable_dim=4, embed_dim=64, depth=4):
        super().__init__()
        self.in_dim = in_dim
        self.maskable_dim = maskable_dim

        # Learned per-feature fill value substituted at masked positions.
        self.mask_token = nn.Parameter(torch.zeros(in_dim))

        stages = []
        width = in_dim  # the first stage reads raw features, later ones the latent
        for _ in range(depth):
            stages.extend((nn.Linear(width, embed_dim), nn.GELU(), nn.LayerNorm(embed_dim)))
            width = embed_dim
        self.encoder = nn.Sequential(*stages)

        self.decoder = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, in_dim),
        )

    def forward(self, x, mask):
        """Reconstruct ``x`` from its visible entries.

        ``mask`` is True where the original value is kept and False where
        ``mask_token`` is substituted; it must broadcast against ``x``.
        """
        visible = torch.where(mask, x, self.mask_token)
        return self.decoder(self.encoder(visible))

In [3]:
MODEL_PATH = "models/model.pt"  # state_dict of the pre-trained masked auto-encoder

In [5]:
# Every later cell uses `model`, so a missing checkpoint must stop the notebook
# here rather than print a message and fail later with a confusing NameError.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"MAE model not found at {MODEL_PATH}!")

print(f"Loading model from {MODEL_PATH}")
model = MaskedAutoEncoder(in_dim=9, embed_dim=128).cuda()
# map_location: load onto the GPU the model lives on, whatever device the
# checkpoint was saved from. weights_only=True: the file only needs to hold a
# state_dict of tensors, so refuse arbitrary pickled objects (torch.load
# otherwise unpickles and can execute code from the file).
state_dict = torch.load(MODEL_PATH, map_location="cuda", weights_only=True)
model.load_state_dict(state_dict)
model.eval()

Loading model from models/model.pt


## Data loading

In [8]:
# Physical input variables; the raw training table is read right after.
cols = ["temperature_00", "salinity_00", "oxygen_00", "phosphate_00"]
df = pd.read_parquet("data/processed/nitrate_00_train_data.parquet")

def encode_geospatial_features(df: pd.DataFrame) -> np.ndarray:
    """Encode position and depth as a ``(len(df), 5)`` float32 feature matrix.

    Columns, in order: ``sin(lat)``, ``cos(lat)``, ``sin(lon)``, ``cos(lon)``
    (inputs in degrees, converted to radians) and depth min-max scaled to
    ``[0, 1]``. The sin/cos pairs keep longitude continuous across the
    +/-180 degree seam.

    NOTE: depth is scaled with this frame's own min/max, so frames encoded
    separately are not on a common depth scale.

    A constant depth column (including a single-row frame) has no range to
    scale by and is encoded as 0 rather than NaN. An empty frame yields a
    ``(0, 5)`` matrix.

    Args:
        df: frame with ``lat`` and ``lon`` (degrees) and ``depth`` columns.

    Returns:
        float32 array of shape ``(len(df), 5)``.
    """
    lat_rad = np.radians(df["lat"].to_numpy())
    lon_rad = np.radians(df["lon"].to_numpy())

    depth = df["depth"].to_numpy(dtype=np.float32)
    if depth.size == 0:
        # min()/max() raise on empty arrays; there is nothing to scale.
        norm_depth = depth
    else:
        depth_min = depth.min()
        depth_range = depth.max() - depth_min
        # Avoid 0/0 for a constant column. NaN depths give a NaN range and
        # still make the whole column NaN.
        if depth_range == 0:
            norm_depth = np.zeros_like(depth)
        else:
            norm_depth = (depth - depth_min) / depth_range

    geo_features = np.stack(
        [np.sin(lat_rad), np.cos(lat_rad), np.sin(lon_rad), np.cos(lon_rad), norm_depth],
        axis=1,
    )
    return geo_features.astype(np.float32)

class Scaler:
    """Per-column z-score scaler backed by ``{column_name: value}`` dicts.

    Wherever a method takes ``cols``, column ``i`` of the tensor is paired with
    the statistics of ``cols[i]``.
    """

    def __init__(self, mean: dict[str, float], std: dict[str, float]):
        self.mean = mean
        self.std = std
        # Column order follows the insertion order of ``mean``.
        self.cols = list(mean.keys())

    @classmethod
    def from_dataframe(cls, df, cols):
        """Build a scaler from the pandas mean and std (sample std, ``ddof=1``) of ``cols``."""
        mean, std = {}, {}
        for name in cols:
            mean[name] = df[name].mean()
            std[name] = df[name].std()
        return cls(mean, std)

    def normalize(self, tensor: torch.Tensor, cols: list[str]) -> torch.Tensor:
        """Z-score the leading ``len(cols)`` columns of a 2-D ``tensor`` in place.

        Columns past ``len(cols)`` are left untouched. The same tensor object
        is returned, so pass a clone if the input must be preserved.
        """
        for position, name in enumerate(cols):
            shifted = tensor[:, position] - self.mean[name]
            tensor[:, position] = shifted / self.std[name]
        return tensor

    def denormalize(self, tensor: torch.Tensor, cols: list[str]) -> torch.Tensor:
        """Return a new tensor with the scaling undone; the last dimension must match ``cols``."""
        like = {"dtype": tensor.dtype, "device": tensor.device}
        stds = torch.tensor([self.std[name] for name in cols], **like)
        means = torch.tensor([self.mean[name] for name in cols], **like)
        return tensor * stds + means

    def _abs_error(self, reconstructed: torch.Tensor, ground_truth: torch.Tensor, cols: list[str]) -> torch.Tensor:
        # Errors are measured in physical units, so both sides are de-normalized first.
        return torch.abs(
            self.denormalize(reconstructed.clone(), cols)
            - self.denormalize(ground_truth.clone(), cols)
        )

    def mae(self, reconstructed: torch.Tensor, ground_truth: torch.Tensor, cols: list[str]) -> float:
        """Mean absolute error over all entries, in de-normalized units."""
        return self._abs_error(reconstructed, ground_truth, cols).mean().item()

    def masked_mae(self, reconstructed: torch.Tensor, ground_truth: torch.Tensor, mask: torch.Tensor, cols: list[str]) -> float:
        """Mean absolute error restricted to entries selected by ``mask``.

        ``mask`` weights each entry's error; the sum is divided by
        ``mask.sum()``, clamped to at least 1 so an all-zero mask yields 0.
        """
        weighted = self._abs_error(reconstructed, ground_truth, cols) * mask
        denominator = mask.sum().clamp(min=1.0)
        return (weighted.sum() / denominator).item()

df = df.dropna(subset=cols).reset_index(drop=True)
scaler = Scaler.from_dataframe(df, cols)

# Feature layout: the len(cols) physical variables first (the only columns
# the scaler touches), then the 5 geospatial encodings.
x = df[cols].to_numpy(dtype=np.float32)
geo = encode_geospatial_features(df)
x_full = np.concatenate([x, geo], axis=1)

# Scaler.normalize edits its argument in place, so give it a private copy.
X = scaler.normalize(torch.tensor(x_full).clone(), cols)

class PredictionDataset(Dataset):
    """Map-style dataset pairing feature row ``X[i]`` with target ``y[i]``.

    Args:
        X: indexable features with one row per sample (a tensor in this notebook).
        y: indexable targets of the same length; this notebook passes a tensor
            for training and a NumPy array for testing, and both collate fine.

    Raises:
        ValueError: if ``X`` and ``y`` differ in length.
    """

    def __init__(self, X: torch.Tensor, y: torch.Tensor):
        # __len__ reports len(X). A shorter y would otherwise only surface as an
        # IndexError part-way through an epoch, far from the real mistake.
        if len(X) != len(y):
            raise ValueError(f"X and y must have the same length, got {len(X)} and {len(y)}")
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

TARGET         = "nitrate_00"
TEST_BBOX      = {                       # Mediterranean Sea
    "lat_min": 30.0, "lat_max": 46.0,
    "lon_min": -6.0, "lon_max": 36.0
}
SEED           = 42
N_JOBS         = -1                      # not referenced by the cells in this notebook
SUB_FRAC       = 0.20                    # fraction of non-test rows kept for training

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Spatial hold-out: every row inside TEST_BBOX (bounds inclusive) is test data.
test_mask = (
    (df["lat"].between(TEST_BBOX["lat_min"], TEST_BBOX["lat_max"])) &
    (df["lon"].between(TEST_BBOX["lon_min"], TEST_BBOX["lon_max"]))
)
# Explicit NumPy / torch bool masks instead of indexing a torch tensor with a
# pandas Series (which relies on implicit sequence conversion).
test_rows = test_mask.to_numpy()
test_rows_t = torch.from_numpy(test_rows)

# Boolean-mask indexing already returns copies.
X_train_large = X[~test_rows_t]
X_test = X[test_rows_t]

# The earlier dropna(subset=cols) did not cover the target column; one NaN
# target would turn the MSE training loss into NaN, so fail loudly here.
y_all = df[TARGET].to_numpy(dtype=np.float32)
assert not np.isnan(y_all).any(), f"{TARGET} contains missing values"
y_train_large = y_all[~test_rows]
y_test = y_all[test_rows]

print(f"X_train_large shape: {X_train_large.shape}")
print(f"y_train_large  shape: {y_train_large.shape}")

# Random SUB_FRAC subsample (without replacement) of the non-test rows.
idx = np.random.choice(X_train_large.shape[0], int((X_train_large.shape[0])*SUB_FRAC), replace=False)

X_train = X_train_large[idx]
y_train = torch.from_numpy(y_train_large[idx])

print(f"X_train shape: {X_train.shape}")
print(f"X_test  shape: {X_test.shape}")

print(f"y_train shape: {y_train.shape}")
print(f"y_test  shape: {y_test.shape}")

train_ds = PredictionDataset(X_train, y_train)
test_ds  = PredictionDataset(X_test, y_test)

train_loader = DataLoader(train_ds, batch_size=256, shuffle=True)
# The whole test region as a single batch; DataLoader rejects batch_size=0,
# hence the max() for an empty region.
test_loader  = DataLoader(test_ds, batch_size=max(len(test_ds), 1), shuffle=False)


X_train_large shape: torch.Size([601393, 9])
y_train_large  shape: (601393,)
X_train shape: torch.Size([120278, 9])
X_test  shape: torch.Size([8519, 9])
y_train shape: torch.Size([120278])
y_test  shape: (8519,)


## Training loop

In [47]:
def train(model_, loader_, n_epochs=20, lr=1e-3, weight_decay=1e-4, verbose=True, gradient_clipping=False):
    """Fit ``model_`` with Adam, cosine learning-rate decay and an MSE loss.

    Args:
        model_: module mapping a feature batch to predictions. The output is
            flattened before the loss, so ``(B,)`` and ``(B, 1)`` both work.
        loader_: iterable of ``(x, y)`` batches.
        n_epochs: passes over ``loader_``; also the cosine-annealing period.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        verbose: print per-epoch loss and R2 when True.
        gradient_clipping: clip the global gradient norm to 1.0 before each step.

    Returns:
        ``(avg_loss, r2)`` for the last epoch: the mean per-batch MSE and the
        R2 over that epoch's training predictions (collected in train mode, so
        dropout is active). Both are NaN when ``n_epochs`` is 0.
    """
    # Train on whatever device the model lives on (the notebook moves models
    # to CUDA); fall back to CUDA for a model without parameters.
    first_param = next(model_.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    # Frozen (requires_grad=False) encoder weights never get gradients, so they
    # are left out of the optimizer.
    trainable = [p for p in model_.parameters() if p.requires_grad]
    opt = torch.optim.Adam(trainable, lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=n_epochs)
    loss_fn = nn.MSELoss()

    # Defined up front so n_epochs=0 returns NaNs instead of raising
    # UnboundLocalError.
    avg_loss, r2 = float("nan"), float("nan")

    for epoch in range(n_epochs):
        model_.train()
        total_loss = 0.0
        epoch_y, epoch_pred = [], []

        for x, y in loader_:
            x, y = x.to(device), y.to(device)
            pred = model_(x).flatten()

            loss = loss_fn(pred, y)
            loss.backward()

            if gradient_clipping:
                torch.nn.utils.clip_grad_norm_(model_.parameters(), max_norm=1.0)

            opt.step()
            opt.zero_grad()

            total_loss += loss.item()
            epoch_y.append(y.detach().cpu())
            epoch_pred.append(pred.detach().cpu())

        scheduler.step()

        # R2 over the whole epoch rather than averaged per batch.
        all_y = torch.cat(epoch_y).numpy()
        all_pred = torch.cat(epoch_pred).numpy()

        r2 = r2_score(all_y, all_pred)
        avg_loss = total_loss / len(loader_)

        if verbose:
            print(f"Epoch {epoch+1}: loss = {avg_loss:.4f}  R2 = {r2:.4f}")

    return avg_loss, r2

def metrics(name, y, pred, verbose=True):
    """Compute regression scores of ``pred`` against ``y``.

    Args:
        name: label printed, left-aligned to 18 characters, before the scores.
        y: true values.
        pred: predicted values, same units as ``y``.
        verbose: print a one-line summary when True.

    Returns:
        ``(mse, rmse, mae, r2)``.
    """
    mse = mean_squared_error(y, pred)
    rmse, mae, r2 = np.sqrt(mse), mean_absolute_error(y, pred), r2_score(y, pred)
    scores = (mse, rmse, mae, r2)
    if verbose:
        print(f"{name:<18} RMSE={rmse:.4f} MSE={mse:.4f} MAE={mae:.4f} R2={r2:.4f}")
    return scores



def evaluate(name, model_, loader_):
    """Score ``model_`` on every batch of ``loader_`` and report via ``metrics``.

    Predictions from all batches are concatenated before scoring, so the
    result does not depend on the loader's batch size.

    Args:
        name: label for the printed metrics line.
        model_: model to score; switched to eval mode (and left there).
        loader_: iterable of ``(x, y)`` tensor batches.

    Returns:
        ``(mse, rmse, mae, r2)`` over all samples yielded by ``loader_``.

    Raises:
        ValueError: if ``loader_`` yields no batches.
    """
    model_.eval()
    # Score on whatever device the model lives on; CUDA if it has no parameters.
    first_param = next(model_.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    ys, preds = [], []
    # Inference only: skip building the autograd graph, which would be large
    # because the notebook feeds the whole test region as one batch.
    with torch.no_grad():
        for x, y in loader_:
            preds.append(model_(x.to(device)).flatten().cpu())
            ys.append(y.cpu())

    if not preds:
        raise ValueError(f"evaluate({name!r}): loader yielded no batches")

    y_all = torch.cat(ys).numpy()
    pred_all = torch.cat(preds).numpy()
    return metrics(name, y_all, pred_all)


def initialize_weights(m):
    """Re-initialize one module in place; meant to be passed to ``Module.apply``.

    - ``nn.Linear``: Xavier-uniform weight, zero bias (if present).
    - ``nn.LayerNorm``: unit scale and zero shift (if present).
    - Any other module type is left untouched.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        # LayerNorm(elementwise_affine=False) has weight=None and bias=None, and
        # LayerNorm(bias=False) has bias=None; nn.init cannot fill None.
        if m.weight is not None:
            nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

## Sanity check: MLP baseline on raw features (no MAE embeddings)

In [58]:
class BaselineMLP(nn.Module):
    """Plain MLP regressor on the raw features, used as a no-embedding baseline.

    Architecture: ``in_dim -> 128 -> 64 -> 1`` with ReLU between the layers.
    For a ``(B, in_dim)`` input, ``forward`` returns shape ``(B,)``.
    """

    def __init__(self, in_dim):
        super().__init__()
        widths = (in_dim, 128, 64)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend([nn.Linear(fan_in, fan_out), nn.ReLU()])
        layers.append(nn.Linear(widths[-1], 1))
        self.predictor = nn.Sequential(*layers)

    def forward(self, x):
        out = self.predictor(x)
        return out.squeeze(-1)

# Ten runs, each reseeded to the same value: every run starts from the same
# weights and shuffle order, so identical printed lines indicate training is
# reproducible. (It does not measure seed-to-seed variance.)
# in_dim=9: the 4 physical variables plus the 5 geospatial encodings.
for run in range(10):
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    pred_model = BaselineMLP(in_dim=9).cuda()
    _ = train(pred_model, train_loader, n_epochs=13, lr=1e-3, verbose=False)
    _ = evaluate("Baseline", pred_model, test_loader)

Baseline           RMSE=1.7460 MSE=3.0484 MAE=1.4035 R2=0.6876
Baseline           RMSE=1.7460 MSE=3.0484 MAE=1.4035 R2=0.6876
Baseline           RMSE=1.7460 MSE=3.0484 MAE=1.4035 R2=0.6876
Baseline           RMSE=1.7460 MSE=3.0484 MAE=1.4035 R2=0.6876
Baseline           RMSE=1.7460 MSE=3.0484 MAE=1.4035 R2=0.6876
Baseline           RMSE=1.7460 MSE=3.0484 MAE=1.4035 R2=0.6876
Baseline           RMSE=1.7460 MSE=3.0484 MAE=1.4035 R2=0.6876
Baseline           RMSE=1.7460 MSE=3.0484 MAE=1.4035 R2=0.6876
Baseline           RMSE=1.7460 MSE=3.0484 MAE=1.4035 R2=0.6876
Baseline           RMSE=1.7460 MSE=3.0484 MAE=1.4035 R2=0.6876


## Simple linear prediction

In [59]:
class LinearPredictionModel(nn.Module):
    """Linear probe: a single ``nn.Linear`` head on a frozen MAE encoder.

    Args:
        in_dim: width of ``mae.encoder``'s output.
        mae: module exposing ``.encoder``; all of its parameters are frozen.

    ``forward`` returns shape ``(B, 1)``; ``train`` flattens it.
    """

    def __init__(self, in_dim, mae):
        super().__init__()
        self.mae = mae
        self.predictor = nn.Linear(in_dim, 1)
        # Only the head is trained.
        self.mae.requires_grad_(False)

    def forward(self, x):
        return self.predictor(self.mae.encoder(x))

torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Linear probe on the frozen 128-d MAE embedding; only the head is trained.
pred_model = LinearPredictionModel(in_dim=128, mae=model).cuda()
pred_model.predictor.apply(initialize_weights)
_ = train(pred_model, train_loader, n_epochs=30, lr=1e-4)

Epoch 1: loss = 440.3712  R2 = -1.3863
Epoch 2: loss = 402.6417  R2 = -1.1819
Epoch 3: loss = 367.7134  R2 = -0.9926
Epoch 4: loss = 335.4680  R2 = -0.8179
Epoch 5: loss = 305.7920  R2 = -0.6571
Epoch 6: loss = 278.6206  R2 = -0.5098
Epoch 7: loss = 253.8073  R2 = -0.3754
Epoch 8: loss = 231.2943  R2 = -0.2534
Epoch 9: loss = 210.9531  R2 = -0.1431
Epoch 10: loss = 192.6362  R2 = -0.0439
Epoch 11: loss = 176.2463  R2 = 0.0450
Epoch 12: loss = 161.6234  R2 = 0.1242
Epoch 13: loss = 148.6730  R2 = 0.1944
Epoch 14: loss = 137.2514  R2 = 0.2562
Epoch 15: loss = 127.2473  R2 = 0.3105
Epoch 16: loss = 118.5088  R2 = 0.3578
Epoch 17: loss = 110.9615  R2 = 0.3987
Epoch 18: loss = 104.4755  R2 = 0.4338
Epoch 19: loss = 98.9636  R2 = 0.4638
Epoch 20: loss = 94.3029  R2 = 0.4890
Epoch 21: loss = 90.4308  R2 = 0.5100
Epoch 22: loss = 87.2564  R2 = 0.5272
Epoch 23: loss = 84.7093  R2 = 0.5410
Epoch 24: loss = 82.7167  R2 = 0.5518
Epoch 25: loss = 81.2037  R2 = 0.5600
Epoch 26: loss = 80.1112  R2 = 

In [60]:
_ = evaluate(name="Linear prediction", model_=pred_model, loader_=test_loader)  # held-out TEST_BBOX region

Linear prediction  RMSE=6.6740 MSE=44.5425 MAE=6.0153 R2=-3.5646


## Multi-layer linear prediction

In [61]:
class MultiLinearPredictionModel(nn.Module):
    """Small MLP head (``in_dim -> 64 -> 1``, ReLU) on a frozen MAE encoder.

    Args:
        in_dim: width of ``mae.encoder``'s output.
        mae: module exposing ``.encoder``; all of its parameters are frozen.

    ``forward`` returns shape ``(B, 1)``.
    """

    def __init__(self, in_dim, mae):
        super().__init__()
        self.mae = mae
        self.predictor = nn.Sequential(nn.Linear(in_dim, 64), nn.ReLU(), nn.Linear(64, 1))
        # Only the head is trained.
        self.mae.requires_grad_(False)

    def forward(self, x):
        embedding = self.mae.encoder(x)
        return self.predictor(embedding)

torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Two-layer MLP head on the frozen 128-d MAE embedding.
pred_model = MultiLinearPredictionModel(in_dim=128, mae=model).cuda()
pred_model.predictor.apply(initialize_weights)
_ = train(pred_model, train_loader, n_epochs=20, lr=1e-4)

Epoch 1: loss = 344.0143  R2 = -0.8644
Epoch 2: loss = 96.0309  R2 = 0.4795
Epoch 3: loss = 20.0958  R2 = 0.8911
Epoch 4: loss = 11.9658  R2 = 0.9352
Epoch 5: loss = 9.8705  R2 = 0.9465
Epoch 6: loss = 8.5703  R2 = 0.9536
Epoch 7: loss = 7.7079  R2 = 0.9582
Epoch 8: loss = 7.1210  R2 = 0.9614
Epoch 9: loss = 6.7034  R2 = 0.9637
Epoch 10: loss = 6.3939  R2 = 0.9654
Epoch 11: loss = 6.1655  R2 = 0.9666
Epoch 12: loss = 5.9914  R2 = 0.9675
Epoch 13: loss = 5.8575  R2 = 0.9683
Epoch 14: loss = 5.7549  R2 = 0.9688
Epoch 15: loss = 5.6777  R2 = 0.9692
Epoch 16: loss = 5.6212  R2 = 0.9695
Epoch 17: loss = 5.5830  R2 = 0.9697
Epoch 18: loss = 5.5587  R2 = 0.9699
Epoch 19: loss = 5.5469  R2 = 0.9699
Epoch 20: loss = 5.5412  R2 = 0.9700


In [62]:
_ = evaluate(name="Multi-layer linear prediction", model_=pred_model, loader_=test_loader)  # held-out TEST_BBOX region

Multi-layer linear prediction RMSE=4.8045 MSE=23.0828 MAE=4.0504 R2=-1.3655


## Linear+Dropout prediction

In [63]:
class DropoutLinearPredictionModel(nn.Module):
    """MLP head with LayerNorm input and dropout, on a frozen MAE encoder.

    Layout: ``LayerNorm -> Linear(in_dim, 256) -> ReLU -> Dropout(dropout)
    -> Linear(256, 64) -> ReLU -> Dropout(dropout / 2) -> Linear(64, 1)``.

    Args:
        in_dim: width of ``mae.encoder``'s output.
        mae: module exposing ``.encoder``; all of its parameters are frozen.
        dropout: dropout probability after the first hidden layer; half of
            it is used after the second.
    """

    def __init__(self, in_dim, mae, dropout=0.3):
        super().__init__()
        self.mae = mae
        hidden = [(in_dim, 256, dropout), (256, 64, dropout / 2)]
        layers = [nn.LayerNorm(in_dim)]
        for fan_in, fan_out, p in hidden:
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(p)]
        layers.append(nn.Linear(64, 1))
        self.predictor = nn.Sequential(*layers)

        # Only the head is trained.
        self.mae.requires_grad_(False)

    def forward(self, x):
        embedding = self.mae.encoder(x)
        return self.predictor(embedding)

torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Dropout MLP head on the frozen 128-d MAE embedding, with gradient-norm clipping.
pred_model = DropoutLinearPredictionModel(in_dim=128, mae=model).cuda()
pred_model.predictor.apply(initialize_weights)
_ = train(pred_model, train_loader, n_epochs=20, lr=1e-4, gradient_clipping=True)

Epoch 1: loss = 98.4124  R2 = 0.4666
Epoch 2: loss = 9.8125  R2 = 0.9468
Epoch 3: loss = 8.8859  R2 = 0.9518
Epoch 4: loss = 8.5097  R2 = 0.9539
Epoch 5: loss = 8.2663  R2 = 0.9552
Epoch 6: loss = 8.0497  R2 = 0.9564
Epoch 7: loss = 7.8850  R2 = 0.9573
Epoch 8: loss = 7.7623  R2 = 0.9579
Epoch 9: loss = 7.6949  R2 = 0.9583
Epoch 10: loss = 7.6750  R2 = 0.9584
Epoch 11: loss = 7.5781  R2 = 0.9589
Epoch 12: loss = 7.5043  R2 = 0.9593
Epoch 13: loss = 7.4382  R2 = 0.9597
Epoch 14: loss = 7.4180  R2 = 0.9598
Epoch 15: loss = 7.3671  R2 = 0.9601
Epoch 16: loss = 7.4033  R2 = 0.9599
Epoch 17: loss = 7.3678  R2 = 0.9601
Epoch 18: loss = 7.4227  R2 = 0.9598
Epoch 19: loss = 7.3792  R2 = 0.9600
Epoch 20: loss = 7.3601  R2 = 0.9601


In [64]:
_ = evaluate(name="Dropout+linear prediction", model_=pred_model, loader_=test_loader)  # held-out TEST_BBOX region

Dropout+linear prediction RMSE=2.4771 MSE=6.1361 MAE=1.9888 R2=0.3712


## Deeper dropout prediction

In [65]:
class DeepDropoutLinearPredictionModel(nn.Module):
    """Four-hidden-layer SiLU head with LayerNorm and dropout, on a frozen MAE encoder.

    Hidden widths 256 -> 128 -> 64 -> 32. The first three stages are preceded
    by a LayerNorm. The first two use ``dropout``, the last two ``dropout / 2``.

    Args:
        in_dim: width of ``mae.encoder``'s output.
        mae: module exposing ``.encoder``; its parameters are always frozen.
        dropout: base dropout probability.

    For a ``(B, ...)`` input, ``forward`` returns shape ``(B,)``.
    """

    def __init__(self, in_dim, mae, dropout=0.3):
        super().__init__()
        self.mae = mae
        # The encoder is frozen unconditionally; only the head is trained.
        self.mae.requires_grad_(False)

        def stage(fan_in, fan_out, p, norm):
            # Optional LayerNorm, then Linear -> SiLU -> Dropout.
            head = [nn.LayerNorm(fan_in)] if norm else []
            return head + [nn.Linear(fan_in, fan_out), nn.SiLU(), nn.Dropout(p)]

        self.predictor = nn.Sequential(
            *stage(in_dim, 256, dropout, norm=True),
            *stage(256, 128, dropout, norm=True),
            *stage(128, 64, dropout / 2, norm=True),
            *stage(64, 32, dropout / 2, norm=False),
            nn.Linear(32, 1),
        )

    def forward(self, x):
        embedding = self.mae.encoder(x)
        # Drop the trailing unit axis: (B, 1) -> (B,).
        return self.predictor(embedding).squeeze(-1)

torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Deeper dropout head; stronger weight decay plus gradient-norm clipping.
pred_model = DeepDropoutLinearPredictionModel(in_dim=128, mae=model).cuda()
pred_model.predictor.apply(initialize_weights)
_ = train(pred_model, train_loader, n_epochs=20, lr=1e-4, weight_decay=5e-4, gradient_clipping=True)

Epoch 1: loss = 166.2842  R2 = 0.0987
Epoch 2: loss = 20.2856  R2 = 0.8901
Epoch 3: loss = 17.0555  R2 = 0.9076
Epoch 4: loss = 15.5918  R2 = 0.9155
Epoch 5: loss = 14.8171  R2 = 0.9197
Epoch 6: loss = 14.1695  R2 = 0.9232
Epoch 7: loss = 13.6580  R2 = 0.9260
Epoch 8: loss = 13.3592  R2 = 0.9276
Epoch 9: loss = 13.0840  R2 = 0.9291
Epoch 10: loss = 12.8278  R2 = 0.9305
Epoch 11: loss = 12.6100  R2 = 0.9317
Epoch 12: loss = 12.3367  R2 = 0.9331
Epoch 13: loss = 12.2020  R2 = 0.9339
Epoch 14: loss = 12.0326  R2 = 0.9348
Epoch 15: loss = 12.0572  R2 = 0.9347
Epoch 16: loss = 12.0316  R2 = 0.9348
Epoch 17: loss = 11.9318  R2 = 0.9353
Epoch 18: loss = 11.9596  R2 = 0.9352
Epoch 19: loss = 11.9422  R2 = 0.9353
Epoch 20: loss = 11.9955  R2 = 0.9350


In [66]:
_ = evaluate(name="Deep dropout+linear prediction", model_=pred_model, loader_=test_loader)  # held-out TEST_BBOX region

Deep dropout+linear prediction RMSE=3.5695 MSE=12.7410 MAE=2.6521 R2=-0.3057


## GELU prediction

In [67]:
class GELUPredictionModel(nn.Module):
    """GELU MLP head (``in_dim -> 256 -> 128 -> 1``) with dropout, on a frozen MAE encoder.

    Args:
        in_dim: width of ``mae.encoder``'s output; sizes the head's input
            LayerNorm and first Linear.
        mae: module exposing ``.encoder``; all of its parameters are frozen.

    ``forward`` returns shape ``(B, 1)``; ``train`` flattens it.
    """

    def __init__(self, in_dim, mae):
        super().__init__()
        self.mae = mae
        # Size the head from in_dim rather than a hard-coded 128, so encoders of
        # other widths work (128 keeps the layout used in this notebook).
        self.predictor = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, 256),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1)
        )

        # Only the head is trained.
        for param in self.mae.parameters():
            param.requires_grad = False

    def forward(self, x):
        enc = self.mae.encoder(x)
        return self.predictor(enc)

torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# GELU MLP head on the frozen 128-d MAE embedding.
pred_model = GELUPredictionModel(in_dim=128, mae=model).cuda()
pred_model.predictor.apply(initialize_weights)
_ = train(pred_model, train_loader, n_epochs=20, lr=1e-4)

Epoch 1: loss = 96.7575  R2 = 0.4755
Epoch 2: loss = 11.1466  R2 = 0.9396
Epoch 3: loss = 9.8448  R2 = 0.9467
Epoch 4: loss = 9.0926  R2 = 0.9507
Epoch 5: loss = 8.5818  R2 = 0.9535
Epoch 6: loss = 8.1666  R2 = 0.9557
Epoch 7: loss = 7.8592  R2 = 0.9574
Epoch 8: loss = 7.6489  R2 = 0.9586
Epoch 9: loss = 7.4292  R2 = 0.9597
Epoch 10: loss = 7.3237  R2 = 0.9603
Epoch 11: loss = 7.1841  R2 = 0.9611
Epoch 12: loss = 7.0667  R2 = 0.9617
Epoch 13: loss = 7.0068  R2 = 0.9620
Epoch 14: loss = 6.9394  R2 = 0.9624
Epoch 15: loss = 6.8988  R2 = 0.9626
Epoch 16: loss = 6.8430  R2 = 0.9629
Epoch 17: loss = 6.8424  R2 = 0.9629
Epoch 18: loss = 6.8251  R2 = 0.9630
Epoch 19: loss = 6.8238  R2 = 0.9630
Epoch 20: loss = 6.7839  R2 = 0.9632


In [68]:
_ = evaluate(name="GELU prediction", model_=pred_model, loader_=test_loader)  # held-out TEST_BBOX region

GELU prediction    RMSE=3.1800 MSE=10.1124 MAE=2.3185 R2=-0.0363


## Residual Block prediction

In [69]:
class ResidualBlock(nn.Module):
    """Width-preserving residual MLP block computing ``LayerNorm(x + F(x))``.

    ``F`` is ``LayerNorm -> Linear -> SiLU -> Dropout -> Linear``.

    Args:
        dim: feature width (input and output).
        dropout: dropout probability inside ``F``.
    """

    def __init__(self, dim, dropout=0.2):
        super().__init__()
        self.block = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        )
        self.ln = nn.LayerNorm(dim)

    def forward(self, x):
        update = self.block(x)
        return self.ln(update + x)


class ResidualPredictionModel(nn.Module):
    """Residual MLP head on a frozen MAE encoder.

    Pipeline: ``input_proj`` (LayerNorm, Linear to 256, SiLU) -> two 256-wide
    ``ResidualBlock``s -> ``output_head`` (256 -> 128 -> 1).

    Args:
        in_dim: width of ``mae.encoder``'s output.
        mae: module exposing ``.encoder``; all of its parameters are frozen.

    ``forward`` returns shape ``(B, 1)``; ``train`` flattens it.
    """

    def __init__(self, in_dim, mae):
        super().__init__()
        self.mae = mae
        # Only the head is trained.
        self.mae.requires_grad_(False)

        self.input_proj = nn.Sequential(nn.LayerNorm(in_dim), nn.Linear(in_dim, 256), nn.SiLU())
        self.residual_blocks = nn.Sequential(
            ResidualBlock(256, dropout=0.3),
            ResidualBlock(256, dropout=0.2),
        )
        self.output_head = nn.Sequential(
            nn.Linear(256, 128), nn.SiLU(), nn.Dropout(0.2), nn.Linear(128, 1)
        )

    def forward(self, x):
        embedding = self.mae.encoder(x)
        hidden = self.residual_blocks(self.input_proj(embedding))
        return self.output_head(hidden)

torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

pred_model = ResidualPredictionModel(in_dim=128, mae=model).cuda()
# Each apply() draws Xavier weights from the global RNG, so this order
# (input_proj, output_head, residual_blocks) is part of reproducibility.
for part in (pred_model.input_proj, pred_model.output_head, pred_model.residual_blocks):
    part.apply(initialize_weights)
_ = train(pred_model, train_loader, n_epochs=20, lr=1e-4)

Epoch 1: loss = 39.6192  R2 = 0.7852
Epoch 2: loss = 7.2985  R2 = 0.9604
Epoch 3: loss = 6.4085  R2 = 0.9653
Epoch 4: loss = 5.9576  R2 = 0.9677
Epoch 5: loss = 5.6980  R2 = 0.9691
Epoch 6: loss = 5.5381  R2 = 0.9700
Epoch 7: loss = 5.3896  R2 = 0.9708
Epoch 8: loss = 5.2798  R2 = 0.9714
Epoch 9: loss = 5.1755  R2 = 0.9720
Epoch 10: loss = 5.1079  R2 = 0.9723
Epoch 11: loss = 5.0338  R2 = 0.9727
Epoch 12: loss = 4.9863  R2 = 0.9730
Epoch 13: loss = 4.9003  R2 = 0.9734
Epoch 14: loss = 4.8753  R2 = 0.9736
Epoch 15: loss = 4.8479  R2 = 0.9737
Epoch 16: loss = 4.7913  R2 = 0.9740
Epoch 17: loss = 4.7931  R2 = 0.9740
Epoch 18: loss = 4.7590  R2 = 0.9742
Epoch 19: loss = 4.7634  R2 = 0.9742
Epoch 20: loss = 4.7702  R2 = 0.9742


In [70]:
_ = evaluate(name="Residual prediction", model_=pred_model, loader_=test_loader)  # held-out TEST_BBOX region

Residual prediction RMSE=2.5244 MSE=6.3724 MAE=2.0607 R2=0.3470


In [71]:
# Same residual head and initialization, trained for 50 instead of 20 epochs.
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

pred_model = ResidualPredictionModel(in_dim=128, mae=model).cuda()
# Initialization order matters for reproducibility (each apply() consumes RNG).
for part in (pred_model.input_proj, pred_model.output_head, pred_model.residual_blocks):
    part.apply(initialize_weights)
_ = train(pred_model, train_loader, n_epochs=50, lr=1e-4)
_ = evaluate("Residual prediction", pred_model, test_loader)

Epoch 1: loss = 39.6192  R2 = 0.7852
Epoch 2: loss = 7.2962  R2 = 0.9605
Epoch 3: loss = 6.4029  R2 = 0.9653
Epoch 4: loss = 5.9476  R2 = 0.9678
Epoch 5: loss = 5.6837  R2 = 0.9692
Epoch 6: loss = 5.5210  R2 = 0.9701
Epoch 7: loss = 5.3646  R2 = 0.9709
Epoch 8: loss = 5.2566  R2 = 0.9715
Epoch 9: loss = 5.1408  R2 = 0.9721
Epoch 10: loss = 5.0510  R2 = 0.9726
Epoch 11: loss = 4.9626  R2 = 0.9731
Epoch 12: loss = 4.9056  R2 = 0.9734
Epoch 13: loss = 4.8054  R2 = 0.9740
Epoch 14: loss = 4.7683  R2 = 0.9742
Epoch 15: loss = 4.7088  R2 = 0.9745
Epoch 16: loss = 4.6245  R2 = 0.9749
Epoch 17: loss = 4.6099  R2 = 0.9750
Epoch 18: loss = 4.5394  R2 = 0.9754
Epoch 19: loss = 4.5141  R2 = 0.9755
Epoch 20: loss = 4.4966  R2 = 0.9756
Epoch 21: loss = 4.4476  R2 = 0.9759
Epoch 22: loss = 4.3945  R2 = 0.9762
Epoch 23: loss = 4.3817  R2 = 0.9763
Epoch 24: loss = 4.3543  R2 = 0.9764
Epoch 25: loss = 4.3112  R2 = 0.9766
Epoch 26: loss = 4.3084  R2 = 0.9767
Epoch 27: loss = 4.2662  R2 = 0.9769
Epoch 28: