This notebook provides a minimal example of using LFP to train a small MLP regressor on the California Housing dataset.

For more complex examples, refer to the experiment notebooks in `./nbs`.

### Imports

In [1]:
import os
import joblib
import random

import numpy as np
import torch
import torch.nn as tnn
import torcheval.metrics
import torchvision.datasets as tvisiondata
import torchvision.transforms as T
from tqdm import tqdm

from sklearn.datasets import fetch_california_housing
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

from lfprop.propagation import (
    propagator_lxt as propagator,
)  # LFP propagator. Alternatively, use propagator_zennit
from lfprop.rewards import reward_functions as rewards  # Reward Functions
from lfprop.rewards import rewards as loss_fns

  from .autonotebook import tqdm as notebook_tqdm


### Parameters

In [2]:
# Run identifiers; in this notebook they are only used to name the output directory below.
model_name = "mlp"
optimizer_name = "sgd"
# Root directory for results. Defaults to the original author's machine-specific path;
# set the LFP_SAVE_ROOT environment variable to run this notebook on another machine.
save_root = os.environ.get(
    "LFP_SAVE_ROOT",
    "/media/lweber/f3ed2aae-a7bf-4a55-b50d-ea8fb534f1f52/reward-backprop/resubmission-1-experiments/test-regression",
)
savepath = os.path.join(save_root, f"{model_name}-{optimizer_name}")
os.makedirs(savepath, exist_ok=True)

input_size = 8  # number of input features (California Housing has 8)
lr = 0.01
batch_size = 128
epochs = 100
momentum = 0.9  # SGD momentum

seed = 0

def set_random_seeds(seed):
    """Seed Python, NumPy and PyTorch (CPU + CUDA) RNGs and force deterministic kernels.

    Also disables cuDNN entirely (``enabled = False``) and its autotuner, trading
    speed for run-to-run reproducibility.

    Note: ``PYTHONHASHSEED`` is only read at interpreter start-up, so setting it here
    affects subprocesses launched afterwards, not the current kernel's ``hash()``.
    """
    for seeder in (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    ):
        seeder(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    cudnn.enabled = False

# Apply the configured seed. Without this call `seed` is never used, so the train/test
# split, weight initialisation and dropout masks differ on every run.
set_random_seeds(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Load Dataset

In [3]:
class PrepareData(torch.utils.data.Dataset):
    """Map-style dataset holding features ``X`` and targets ``y`` as tensors.

    Items are returned as float32 ``(features, target)`` pairs.

    :param X: feature matrix (numpy array / array-like, or a torch tensor)
    :param y: targets (numpy array / array-like, or a torch tensor)
    :param scale_X: if True and ``X`` is not already a tensor, standardize each column
        with sklearn's StandardScaler before conversion. Tensor inputs are used as-is.
    """

    def __init__(self, X, y, scale_X=True):
        if not torch.is_tensor(X):
            if scale_X:
                # NOTE: the scaler is fit on every row passed in; if X also contains
                # test rows, their statistics leak into the training inputs.
                X = StandardScaler().fit_transform(X)
            X = torch.from_numpy(np.asarray(X))
        if not torch.is_tensor(y):
            y = torch.from_numpy(np.asarray(y))
        # Assign unconditionally so tensor inputs and scale_X=False also work
        # (otherwise __len__/__getitem__ would hit missing attributes).
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # Cast per item so float64 numpy-derived data matches float32 model weights.
        return self.X[idx].type(torch.float32), self.y[idx].type(torch.float32)

X, y = fetch_california_housing(return_X_y=True)

# Create train/test index splits (70/30); `random_state=seed` makes the split reproducible.
train, test = train_test_split(list(range(X.shape[0])), test_size=.3, random_state=seed)

# NOTE(review): with scale_X=True the StandardScaler inside PrepareData is fit on *all*
# rows, so test-set statistics leak into the training inputs. For a clean evaluation,
# fit the scaler on X[train] only.
ds = PrepareData(X, y=y, scale_X=True)

# Both loaders share `ds`; the samplers restrict each to its own index subset.
training_loader = torch.utils.data.DataLoader(
    ds, batch_size=batch_size, sampler=torch.utils.data.SubsetRandomSampler(train)
)
test_loader = torch.utils.data.DataLoader(
    ds, batch_size=batch_size, sampler=torch.utils.data.SubsetRandomSampler(test)
)

### Load Model

In [4]:
class RegressionModel(torch.nn.Module):
    """MLP regressor: input_size -> 256 -> 128 -> 1.

    Each hidden Linear layer is followed by ``activation()`` and Dropout(p=0.5).
    All layers live in ``self.features`` at indices 0..6, so state_dict keys such as
    ``features.0.weight`` and ``features.6.bias`` are stable.
    """

    def __init__(self, input_size, activation=tnn.ReLU):
        super().__init__()
        widths = (input_size, 256, 128)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend([tnn.Linear(fan_in, fan_out), activation(), tnn.Dropout(p=0.5)])
        layers.append(tnn.Linear(widths[-1], 1))
        self.features = tnn.Sequential(*layers)

    def forward(self, X):
        """Map inputs of shape (..., input_size) to predictions of shape (..., 1)."""
        return self.features(X)

# nn.Module.to moves parameters in place and returns the module itself.
model = RegressionModel(input_size=input_size).to(device)
model

RegressionModel(
  (features): Sequential(
    (0): Linear(in_features=8, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=256, out_features=128, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=128, out_features=1, bias=True)
  )
)

### Set Up LFP

In [5]:
# Initialize the LFP Composite (cf. "composites" in zennit or lxt).
# This call is the same whether the lxt or zennit backend is used (propagator_lxt and propagator_zennit).
# LFP-Epsilon is used here. The commented-out lines below show alternative composites
# (a Hebbian-Epsilon variant with `use_oja=True`, and a Gamma composite) that can be swapped in.
propagation_composite = propagator.LFPEpsilonComposite()
#propagation_composite = propagator.LFPHebbianEpsilonComposite(use_oja=True)
#propagation_composite = propagator.LFPGammaComposite(gamma=0.0)

# Initialize the Reward Function.
# A custom regression reward is defined inline below; the library's built-in reward
# functions live in ./lfp/rewards/reward_functions.py.
class RegressionReward:
    """Reward for scalar regression: cubed residual, signed by the prediction.

    Elementwise: ``reward = (labels - logits)**3 * sign(logits)``, with ``labels``
    reshaped to ``logits.shape``. Cubing keeps the residual's sign while emphasising
    large errors; the ``sign(logits)`` factor makes the reward 0 wherever a prediction
    is exactly 0. NOTE(review): the sign factor presumably compensates for LFP
    normalising by the sign of the output — TODO confirm.
    """

    def __init__(self, device, **kwargs):
        # Stored only; __call__ does not use it. Extra keyword arguments are ignored.
        self.device = device

    def __call__(self, logits, labels):
        """
        Compute the reward.
        :param logits: model predictions, e.g. of shape (batch, 1)
        :param labels: targets with the same number of elements as ``logits``
        :return: reward tensor with the same shape as ``logits``
        """
        residual = labels.view(logits.shape) - logits
        return residual.pow(3) * logits.sign()
    
reward_func = RegressionReward(device=device)

# lfp_step writes the LFP updates into each parameter's .grad, so a standard torch
# optimizer can apply them unchanged.
optimizer = torch.optim.SGD(params=model.parameters(), lr=lr, momentum=momentum)

### Set Up Simple Evaluation using torcheval

In [6]:
def eval_model(loader):
    """
    Evaluates the model on a single dataset.

    Runs the global ``model`` in eval mode, without gradient tracking, over every
    batch of ``loader`` and accumulates two torcheval metrics:

    - "reward": mean of ``reward_func(outputs, labels)`` over all reward elements seen
    - "mse": mean squared error between predictions and labels

    :param loader: iterable of (inputs, labels) batches, e.g. a torch DataLoader
    :return: dict mapping metric name to its computed value as a numpy array
    """
    eval_metrics = {
        "reward": torcheval.metrics.Mean(device=device),
        "mse": torcheval.metrics.MeanSquaredError(device=device),
    }

    model.eval()

    # Evaluation only reads predictions, so no autograd graph is needed anywhere here.
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            # as_tensor reuses an existing tensor instead of copying it; torch.tensor(tensor)
            # copies and emits a UserWarning on every batch.
            labels = torch.as_tensor(labels).to(device)

            outputs = model(inputs)
            reward = reward_func(outputs, labels)

            eval_metrics["reward"].update(reward)
            eval_metrics["mse"].update(outputs.view(labels.shape), labels)

    return {name: metric.compute().detach().cpu().numpy() for name, metric in eval_metrics.items()}

### Training Loop

In [7]:
def lfp_step(inputs: torch.Tensor, labels: torch.Tensor) -> None:
    """
    Performs a single training step using LFP. This is quite similar to a standard gradient descent training loop.

    Instead of backpropagating a loss gradient, the reward from ``reward_func`` is propagated
    through the model under the LFP composite. The resulting per-parameter feedback is
    negated, written into ``.grad``, norm-clipped and applied by ``optimizer``.

    Relies on the notebook-level globals ``model``, ``optimizer``, ``propagation_composite``,
    ``reward_func`` and ``device``.

    :param inputs: batch of input features; expected to already be on ``device`` (the training loop moves it)
    :param labels: batch of regression targets; expected to already be on ``device``
    :return: None; model parameters are updated in place by ``optimizer.step()``
    """
    # Set Model to training mode (activates the Dropout layers)
    model.train()

    with torch.enable_grad():
        # Zero Optimizer
        optimizer.zero_grad()

        # This applies LFP Hooks/Functions (which depends on whether lxt or zennit backend is used)
        with propagation_composite.context(model) as modified:
            # Fresh leaf tensor so torch.autograd.grad below has an input to propagate back to.
            inputs = inputs.detach().requires_grad_(True)
            outputs = modified(inputs)

            # Calculate reward.
            # The numpy round-trip yields a fresh tensor with no autograd history, so the
            # reward computation's graph is not kept in memory.
            reward = torch.from_numpy(reward_func(outputs, labels).detach().cpu().numpy()).to(device)

            # Propagate the reward back through the LFP-modified model. Per the backend, this writes
            # the feedback into the .feedback attribute of parameters; the returned input "gradient" is discarded.
            torch.autograd.grad((outputs,), (inputs,), grad_outputs=(reward,), retain_graph=False)[0]

            # Write LFP Values into .grad attributes. Note the negative sign: LFP maximizes reward,
            # whereas torch optimizers minimize, so the feedback is negated before optimizer.step().
            for name, param in model.named_parameters():
                param.grad = -param.feedback

            # Update Clipping: max_norm=3.0 with norm_type=2.0 (L2 norm over all parameters).
            # Training may become unstable otherwise, especially in small models with large learning rates.
            # In larger models (e.g., VGG, ResNet), where smaller learning rates are generally utilized, not clipping updates may result in better performance.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0, 2.0)

            # Optimization step
            optimizer.step()

    # Set Model back to eval mode
    model.eval()


# Training Loop
result = {
    "mse_train": [],
    "mse_test": [],
    "reward_train": [],
    "reward_test": [],
}
for epoch in range(epochs):
    with tqdm(total=len(training_loader)) as pbar:
        # Iterate over Data Loader
        for index, (inputs, labels) in enumerate(training_loader):
            inputs = inputs.to(device)
            labels = torch.tensor(labels).to(device)

            # Perform Update Step
            lfp_step(inputs, labels)

            # Update Progress Bar
            pbar.update(1)

    # Evaluate and print performance after every epoch
    eval_stats_train = eval_model(training_loader)
    eval_stats_test = eval_model(test_loader)
    print(
        "Epoch {}/{}: (Train Reward) {:.2f}; (Train MSE) {:.2f}; (Val Reward) {:.2f}; (Val MSE) {:.2f}".format(
            epoch + 1,
            epochs,
            float(np.mean(eval_stats_train["reward"])),
            float(eval_stats_train["mse"]),
            float(np.mean(eval_stats_test["reward"])),
            float(eval_stats_test["mse"]),
        )
    )

    result["mse_train"].append(float(np.mean(eval_stats_train["mse"])))
    result["mse_test"].append(float(np.mean(eval_stats_test["mse"])))
    result["reward_train"].append(float(np.mean(eval_stats_train["reward"])))
    result["reward_test"].append(float(np.mean(eval_stats_test["reward"])))

  labels = torch.tensor(labels).to(device)
  warn(
100%|██████████| 113/113 [00:01<00:00, 72.84it/s]
  labels = torch.tensor(labels).to(device)


Epoch 1/100: (Train Reward) -0.01; (Train MSE) 0.65; (Val Reward) 0.04; (Val MSE) 0.63


100%|██████████| 113/113 [00:00<00:00, 119.35it/s]


Epoch 2/100: (Train Reward) 0.09; (Train MSE) 0.61; (Val Reward) 0.10; (Val MSE) 0.60


100%|██████████| 113/113 [00:00<00:00, 118.85it/s]


Epoch 3/100: (Train Reward) -0.09; (Train MSE) 0.59; (Val Reward) -0.10; (Val MSE) 0.58


100%|██████████| 113/113 [00:00<00:00, 118.79it/s]


Epoch 4/100: (Train Reward) -0.02; (Train MSE) 0.51; (Val Reward) -0.07; (Val MSE) 0.51


100%|██████████| 113/113 [00:01<00:00, 104.34it/s]


Epoch 5/100: (Train Reward) 0.12; (Train MSE) 0.49; (Val Reward) 0.10; (Val MSE) 0.49


100%|██████████| 113/113 [00:00<00:00, 120.88it/s]


Epoch 6/100: (Train Reward) 0.03; (Train MSE) 0.44; (Val Reward) 0.05; (Val MSE) 0.44


100%|██████████| 113/113 [00:00<00:00, 114.86it/s]


Epoch 7/100: (Train Reward) 0.01; (Train MSE) 0.43; (Val Reward) 0.01; (Val MSE) 0.44


100%|██████████| 113/113 [00:00<00:00, 114.72it/s]


Epoch 8/100: (Train Reward) 0.05; (Train MSE) 0.42; (Val Reward) 0.03; (Val MSE) 0.43


100%|██████████| 113/113 [00:00<00:00, 121.13it/s]


Epoch 9/100: (Train Reward) 0.02; (Train MSE) 0.44; (Val Reward) 0.01; (Val MSE) 0.44


100%|██████████| 113/113 [00:00<00:00, 119.96it/s]


Epoch 10/100: (Train Reward) 0.40; (Train MSE) 0.40; (Val Reward) 0.39; (Val MSE) 0.40


100%|██████████| 113/113 [00:00<00:00, 120.22it/s]


Epoch 11/100: (Train Reward) 0.14; (Train MSE) 0.37; (Val Reward) 0.13; (Val MSE) 0.38


100%|██████████| 113/113 [00:00<00:00, 119.90it/s]


Epoch 12/100: (Train Reward) 0.06; (Train MSE) 0.39; (Val Reward) 0.05; (Val MSE) 0.40


100%|██████████| 113/113 [00:01<00:00, 112.98it/s]


Epoch 13/100: (Train Reward) 0.05; (Train MSE) 0.41; (Val Reward) 0.03; (Val MSE) 0.41


100%|██████████| 113/113 [00:01<00:00, 95.91it/s] 


Epoch 14/100: (Train Reward) 0.19; (Train MSE) 0.34; (Val Reward) 0.18; (Val MSE) 0.35


100%|██████████| 113/113 [00:01<00:00, 112.43it/s]


Epoch 15/100: (Train Reward) 0.10; (Train MSE) 0.36; (Val Reward) 0.08; (Val MSE) 0.37


100%|██████████| 113/113 [00:00<00:00, 115.62it/s]


Epoch 16/100: (Train Reward) -0.13; (Train MSE) 0.43; (Val Reward) -0.14; (Val MSE) 0.44


100%|██████████| 113/113 [00:00<00:00, 134.84it/s]


Epoch 17/100: (Train Reward) 0.28; (Train MSE) 0.35; (Val Reward) 0.27; (Val MSE) 0.36


100%|██████████| 113/113 [00:01<00:00, 100.62it/s]


Epoch 18/100: (Train Reward) 0.17; (Train MSE) 0.36; (Val Reward) 0.15; (Val MSE) 0.36


100%|██████████| 113/113 [00:01<00:00, 112.05it/s]


Epoch 19/100: (Train Reward) -0.01; (Train MSE) 0.39; (Val Reward) -0.06; (Val MSE) 0.40


100%|██████████| 113/113 [00:01<00:00, 111.25it/s]


Epoch 20/100: (Train Reward) 0.13; (Train MSE) 0.38; (Val Reward) 0.11; (Val MSE) 0.38


100%|██████████| 113/113 [00:00<00:00, 117.03it/s]


Epoch 21/100: (Train Reward) -0.07; (Train MSE) 0.42; (Val Reward) -0.08; (Val MSE) 0.43


100%|██████████| 113/113 [00:00<00:00, 120.83it/s]


Epoch 22/100: (Train Reward) 0.39; (Train MSE) 0.37; (Val Reward) 0.38; (Val MSE) 0.37


100%|██████████| 113/113 [00:00<00:00, 119.81it/s]


Epoch 23/100: (Train Reward) 0.10; (Train MSE) 0.35; (Val Reward) 0.09; (Val MSE) 0.35


100%|██████████| 113/113 [00:00<00:00, 116.61it/s]


Epoch 24/100: (Train Reward) 0.01; (Train MSE) 0.40; (Val Reward) -0.00; (Val MSE) 0.40


100%|██████████| 113/113 [00:00<00:00, 118.70it/s]


Epoch 25/100: (Train Reward) -0.15; (Train MSE) 0.44; (Val Reward) -0.16; (Val MSE) 0.45


100%|██████████| 113/113 [00:00<00:00, 118.33it/s]


Epoch 26/100: (Train Reward) 0.03; (Train MSE) 0.38; (Val Reward) 0.01; (Val MSE) 0.38


100%|██████████| 113/113 [00:01<00:00, 103.46it/s]


Epoch 27/100: (Train Reward) 0.03; (Train MSE) 0.37; (Val Reward) 0.01; (Val MSE) 0.38


100%|██████████| 113/113 [00:00<00:00, 119.82it/s]


Epoch 28/100: (Train Reward) 0.13; (Train MSE) 0.35; (Val Reward) 0.11; (Val MSE) 0.36


100%|██████████| 113/113 [00:00<00:00, 124.66it/s]


Epoch 29/100: (Train Reward) -0.00; (Train MSE) 0.39; (Val Reward) -0.01; (Val MSE) 0.40


100%|██████████| 113/113 [00:00<00:00, 136.58it/s]


Epoch 30/100: (Train Reward) 0.11; (Train MSE) 0.35; (Val Reward) 0.09; (Val MSE) 0.35


100%|██████████| 113/113 [00:00<00:00, 114.63it/s]


Epoch 31/100: (Train Reward) 0.35; (Train MSE) 0.35; (Val Reward) 0.32; (Val MSE) 0.35


100%|██████████| 113/113 [00:00<00:00, 125.88it/s]


Epoch 32/100: (Train Reward) 0.14; (Train MSE) 0.33; (Val Reward) 0.13; (Val MSE) 0.34


100%|██████████| 113/113 [00:00<00:00, 113.79it/s]


Epoch 33/100: (Train Reward) -0.17; (Train MSE) 0.44; (Val Reward) -0.18; (Val MSE) 0.45


100%|██████████| 113/113 [00:01<00:00, 112.61it/s]


Epoch 34/100: (Train Reward) 0.18; (Train MSE) 0.35; (Val Reward) 0.17; (Val MSE) 0.35


100%|██████████| 113/113 [00:01<00:00, 112.31it/s]


Epoch 35/100: (Train Reward) -0.04; (Train MSE) 0.39; (Val Reward) -0.05; (Val MSE) 0.41


100%|██████████| 113/113 [00:01<00:00, 108.39it/s]


Epoch 36/100: (Train Reward) 0.02; (Train MSE) 0.38; (Val Reward) 0.01; (Val MSE) 0.39


100%|██████████| 113/113 [00:01<00:00, 107.39it/s]


Epoch 37/100: (Train Reward) 0.15; (Train MSE) 0.33; (Val Reward) 0.14; (Val MSE) 0.33


100%|██████████| 113/113 [00:01<00:00, 109.03it/s]


Epoch 38/100: (Train Reward) 0.06; (Train MSE) 0.36; (Val Reward) 0.05; (Val MSE) 0.36


100%|██████████| 113/113 [00:01<00:00, 108.24it/s]


Epoch 39/100: (Train Reward) -0.03; (Train MSE) 0.37; (Val Reward) -0.05; (Val MSE) 0.37


100%|██████████| 113/113 [00:01<00:00, 109.79it/s]


Epoch 40/100: (Train Reward) -0.07; (Train MSE) 0.38; (Val Reward) -0.08; (Val MSE) 0.39


100%|██████████| 113/113 [00:01<00:00, 109.41it/s]


Epoch 41/100: (Train Reward) 0.15; (Train MSE) 0.35; (Val Reward) 0.13; (Val MSE) 0.35


100%|██████████| 113/113 [00:00<00:00, 113.90it/s]


Epoch 42/100: (Train Reward) 0.07; (Train MSE) 0.38; (Val Reward) 0.06; (Val MSE) 0.39


100%|██████████| 113/113 [00:01<00:00, 111.59it/s]


Epoch 43/100: (Train Reward) 0.17; (Train MSE) 0.34; (Val Reward) 0.15; (Val MSE) 0.35


100%|██████████| 113/113 [00:01<00:00, 98.90it/s] 


Epoch 44/100: (Train Reward) 0.13; (Train MSE) 0.33; (Val Reward) 0.11; (Val MSE) 0.34


100%|██████████| 113/113 [00:00<00:00, 118.40it/s]


Epoch 45/100: (Train Reward) 0.10; (Train MSE) 0.33; (Val Reward) 0.09; (Val MSE) 0.34


100%|██████████| 113/113 [00:00<00:00, 122.82it/s]


Epoch 46/100: (Train Reward) -0.05; (Train MSE) 0.40; (Val Reward) -0.06; (Val MSE) 0.41


100%|██████████| 113/113 [00:00<00:00, 121.71it/s]


Epoch 47/100: (Train Reward) -0.06; (Train MSE) 0.39; (Val Reward) -0.07; (Val MSE) 0.39


100%|██████████| 113/113 [00:00<00:00, 120.63it/s]


Epoch 48/100: (Train Reward) -0.07; (Train MSE) 0.39; (Val Reward) -0.08; (Val MSE) 0.40


100%|██████████| 113/113 [00:00<00:00, 117.81it/s]


Epoch 49/100: (Train Reward) 0.32; (Train MSE) 0.34; (Val Reward) 0.31; (Val MSE) 0.34


100%|██████████| 113/113 [00:00<00:00, 120.60it/s]


Epoch 50/100: (Train Reward) -0.00; (Train MSE) 0.36; (Val Reward) -0.02; (Val MSE) 0.37


100%|██████████| 113/113 [00:00<00:00, 120.28it/s]


Epoch 51/100: (Train Reward) -0.08; (Train MSE) 0.38; (Val Reward) -0.10; (Val MSE) 0.40


100%|██████████| 113/113 [00:00<00:00, 117.93it/s]


Epoch 52/100: (Train Reward) 0.10; (Train MSE) 0.32; (Val Reward) 0.09; (Val MSE) 0.33


100%|██████████| 113/113 [00:00<00:00, 120.39it/s]


Epoch 53/100: (Train Reward) 0.06; (Train MSE) 0.33; (Val Reward) 0.05; (Val MSE) 0.34


100%|██████████| 113/113 [00:00<00:00, 120.43it/s]


Epoch 54/100: (Train Reward) 0.11; (Train MSE) 0.32; (Val Reward) 0.10; (Val MSE) 0.33


100%|██████████| 113/113 [00:00<00:00, 118.37it/s]


Epoch 55/100: (Train Reward) 0.01; (Train MSE) 0.36; (Val Reward) -0.00; (Val MSE) 0.37


100%|██████████| 113/113 [00:00<00:00, 121.32it/s]


Epoch 56/100: (Train Reward) -0.02; (Train MSE) 0.37; (Val Reward) -0.03; (Val MSE) 0.38


100%|██████████| 113/113 [00:00<00:00, 121.71it/s]


Epoch 57/100: (Train Reward) 0.08; (Train MSE) 0.33; (Val Reward) 0.08; (Val MSE) 0.34


100%|██████████| 113/113 [00:00<00:00, 114.46it/s]


Epoch 58/100: (Train Reward) -0.04; (Train MSE) 0.37; (Val Reward) -0.05; (Val MSE) 0.38


100%|██████████| 113/113 [00:01<00:00, 112.83it/s]


Epoch 59/100: (Train Reward) -0.14; (Train MSE) 0.41; (Val Reward) -0.16; (Val MSE) 0.43


100%|██████████| 113/113 [00:01<00:00, 109.39it/s]


Epoch 60/100: (Train Reward) 0.15; (Train MSE) 0.31; (Val Reward) 0.15; (Val MSE) 0.32


100%|██████████| 113/113 [00:01<00:00, 93.49it/s] 


Epoch 61/100: (Train Reward) -0.16; (Train MSE) 0.43; (Val Reward) -0.18; (Val MSE) 0.44


100%|██████████| 113/113 [00:01<00:00, 107.60it/s]


Epoch 62/100: (Train Reward) 0.04; (Train MSE) 0.34; (Val Reward) 0.03; (Val MSE) 0.35


100%|██████████| 113/113 [00:01<00:00, 111.73it/s]


Epoch 63/100: (Train Reward) 0.12; (Train MSE) 0.33; (Val Reward) 0.12; (Val MSE) 0.34


100%|██████████| 113/113 [00:00<00:00, 115.42it/s]


Epoch 64/100: (Train Reward) 0.07; (Train MSE) 0.34; (Val Reward) 0.05; (Val MSE) 0.34


100%|██████████| 113/113 [00:00<00:00, 113.40it/s]


Epoch 65/100: (Train Reward) -0.03; (Train MSE) 0.37; (Val Reward) -0.04; (Val MSE) 0.38


100%|██████████| 113/113 [00:00<00:00, 114.72it/s]


Epoch 66/100: (Train Reward) -0.01; (Train MSE) 0.35; (Val Reward) -0.02; (Val MSE) 0.36


100%|██████████| 113/113 [00:01<00:00, 109.49it/s]


Epoch 67/100: (Train Reward) -0.02; (Train MSE) 0.38; (Val Reward) -0.03; (Val MSE) 0.39


100%|██████████| 113/113 [00:00<00:00, 113.17it/s]


Epoch 68/100: (Train Reward) 0.12; (Train MSE) 0.33; (Val Reward) 0.11; (Val MSE) 0.34


100%|██████████| 113/113 [00:01<00:00, 108.08it/s]


Epoch 69/100: (Train Reward) 0.16; (Train MSE) 0.31; (Val Reward) 0.15; (Val MSE) 0.32


100%|██████████| 113/113 [00:01<00:00, 94.88it/s] 


Epoch 70/100: (Train Reward) 0.05; (Train MSE) 0.34; (Val Reward) 0.04; (Val MSE) 0.36


100%|██████████| 113/113 [00:01<00:00, 110.79it/s]


Epoch 71/100: (Train Reward) -0.17; (Train MSE) 0.43; (Val Reward) -0.19; (Val MSE) 0.45


100%|██████████| 113/113 [00:00<00:00, 121.78it/s]


Epoch 72/100: (Train Reward) -0.13; (Train MSE) 0.40; (Val Reward) -0.14; (Val MSE) 0.42


100%|██████████| 113/113 [00:00<00:00, 119.05it/s]


Epoch 73/100: (Train Reward) -0.08; (Train MSE) 0.38; (Val Reward) -0.09; (Val MSE) 0.39


100%|██████████| 113/113 [00:00<00:00, 119.15it/s]


Epoch 74/100: (Train Reward) 0.10; (Train MSE) 0.32; (Val Reward) 0.09; (Val MSE) 0.33


100%|██████████| 113/113 [00:00<00:00, 117.95it/s]


Epoch 75/100: (Train Reward) -0.08; (Train MSE) 0.38; (Val Reward) -0.09; (Val MSE) 0.39


100%|██████████| 113/113 [00:00<00:00, 119.90it/s]


Epoch 76/100: (Train Reward) -0.17; (Train MSE) 0.42; (Val Reward) -0.18; (Val MSE) 0.44


100%|██████████| 113/113 [00:00<00:00, 118.05it/s]


Epoch 77/100: (Train Reward) 0.02; (Train MSE) 0.32; (Val Reward) 0.01; (Val MSE) 0.34


100%|██████████| 113/113 [00:00<00:00, 119.43it/s]


Epoch 78/100: (Train Reward) 0.13; (Train MSE) 0.30; (Val Reward) 0.13; (Val MSE) 0.31


100%|██████████| 113/113 [00:01<00:00, 101.45it/s]


Epoch 79/100: (Train Reward) 0.00; (Train MSE) 0.34; (Val Reward) -0.01; (Val MSE) 0.36


100%|██████████| 113/113 [00:00<00:00, 118.62it/s]


Epoch 80/100: (Train Reward) 0.17; (Train MSE) 0.33; (Val Reward) 0.16; (Val MSE) 0.33


100%|██████████| 113/113 [00:00<00:00, 119.19it/s]


Epoch 81/100: (Train Reward) 0.13; (Train MSE) 0.32; (Val Reward) 0.13; (Val MSE) 0.33


100%|██████████| 113/113 [00:00<00:00, 118.55it/s]


Epoch 82/100: (Train Reward) 0.00; (Train MSE) 0.36; (Val Reward) -0.00; (Val MSE) 0.37


100%|██████████| 113/113 [00:01<00:00, 111.76it/s]


Epoch 83/100: (Train Reward) -0.07; (Train MSE) 0.38; (Val Reward) -0.08; (Val MSE) 0.40


100%|██████████| 113/113 [00:01<00:00, 111.65it/s]


Epoch 84/100: (Train Reward) -0.05; (Train MSE) 0.37; (Val Reward) -0.05; (Val MSE) 0.39


100%|██████████| 113/113 [00:01<00:00, 108.55it/s]


Epoch 85/100: (Train Reward) -0.06; (Train MSE) 0.37; (Val Reward) -0.07; (Val MSE) 0.39


100%|██████████| 113/113 [00:01<00:00, 112.88it/s]


Epoch 86/100: (Train Reward) 0.06; (Train MSE) 0.33; (Val Reward) 0.05; (Val MSE) 0.34


100%|██████████| 113/113 [00:01<00:00, 111.99it/s]


Epoch 87/100: (Train Reward) -0.04; (Train MSE) 0.37; (Val Reward) -0.05; (Val MSE) 0.37


100%|██████████| 113/113 [00:01<00:00, 93.75it/s] 


Epoch 88/100: (Train Reward) 0.07; (Train MSE) 0.32; (Val Reward) 0.07; (Val MSE) 0.33


100%|██████████| 113/113 [00:01<00:00, 108.99it/s]


Epoch 89/100: (Train Reward) 0.16; (Train MSE) 0.31; (Val Reward) 0.16; (Val MSE) 0.32


100%|██████████| 113/113 [00:01<00:00, 111.20it/s]


Epoch 90/100: (Train Reward) 0.05; (Train MSE) 0.34; (Val Reward) 0.04; (Val MSE) 0.36


100%|██████████| 113/113 [00:01<00:00, 112.01it/s]


Epoch 91/100: (Train Reward) 0.09; (Train MSE) 0.33; (Val Reward) 0.09; (Val MSE) 0.34


100%|██████████| 113/113 [00:01<00:00, 110.90it/s]


Epoch 92/100: (Train Reward) 0.16; (Train MSE) 0.35; (Val Reward) 0.14; (Val MSE) 0.35


100%|██████████| 113/113 [00:01<00:00, 112.14it/s]


Epoch 93/100: (Train Reward) -0.08; (Train MSE) 0.38; (Val Reward) -0.09; (Val MSE) 0.39


100%|██████████| 113/113 [00:00<00:00, 113.89it/s]


Epoch 94/100: (Train Reward) 0.12; (Train MSE) 0.32; (Val Reward) 0.12; (Val MSE) 0.33


100%|██████████| 113/113 [00:00<00:00, 116.22it/s]


Epoch 95/100: (Train Reward) 0.18; (Train MSE) 0.32; (Val Reward) 0.17; (Val MSE) 0.32


100%|██████████| 113/113 [00:00<00:00, 116.74it/s]


Epoch 96/100: (Train Reward) 0.03; (Train MSE) 0.34; (Val Reward) 0.03; (Val MSE) 0.36


100%|██████████| 113/113 [00:01<00:00, 102.60it/s]


Epoch 97/100: (Train Reward) 0.14; (Train MSE) 0.31; (Val Reward) 0.14; (Val MSE) 0.32


100%|██████████| 113/113 [00:00<00:00, 118.33it/s]


Epoch 98/100: (Train Reward) 0.08; (Train MSE) 0.33; (Val Reward) 0.07; (Val MSE) 0.34


100%|██████████| 113/113 [00:00<00:00, 118.60it/s]


Epoch 99/100: (Train Reward) -0.31; (Train MSE) 0.49; (Val Reward) -0.33; (Val MSE) 0.51


100%|██████████| 113/113 [00:00<00:00, 120.02it/s]


Epoch 100/100: (Train Reward) 0.01; (Train MSE) 0.34; (Val Reward) -0.00; (Val MSE) 0.35
