# Causal Knowledge Transfer for Safe Reinforcement Learning

### Train Agents
#### Sumo Network File
When editing the SUMO network (`nets/simple_unprotected_right.net.xml`), never edit the XML directly. Instead, make the desired changes in `nets/netconfig` and regenerate the `net.xml` by running `generate_config.sh`.
#### Sumo Route File
The route file is now generated by `generate_config.sh` as well.
#### Reward Function
* TODO: Come up with a fitting reward function that penalises collisions
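
One possible shape for such a reward, as a sketch only: combine the change in cumulative waiting time with a fixed penalty per collision. The function name, its arguments, and the default penalty are assumptions for illustration and are not part of the environment's API.

```python
def collision_penalised_reward(
    waiting_time_delta: float,
    n_collisions: int,
    collision_penalty: float = 100.0,  # assumed value; needs tuning
) -> float:
    """Hypothetical reward: favour reduced waiting time, penalise collisions.

    waiting_time_delta: previous cumulative waiting time minus the current one
        (positive when traffic flows better than in the last step).
    n_collisions: number of collisions observed during the last step.
    """
    if n_collisions < 0:
        raise ValueError("n_collisions must be non-negative")
    return waiting_time_delta - collision_penalty * n_collisions


# A step that reduced waiting time by 3 s but caused one collision
print(collision_penalised_reward(3.0, 1))  # -97.0
```

The penalty factor is the knob for the safety/performance trade-off: a larger value makes the agent more conservative at the cost of longer waiting times.
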
#### RL Training
* TODO: Come up with Hyperparameters for the training loop

**Desired Output: Trained_Model.zip**

### Creating the Sumo Environment

In [1]:
from pathlib import Path
import xml.etree.ElementTree as ET

# Environment Variables (Ensure SUMO_HOME is properly set)
sumo_home = %env SUMO_HOME

# Constants / Parameters
SPEED = 22.22
FRICTION = 1.0
INSERT_PROBABILITY = 0.1
DURATION = 3600
REPEAT_PERIOD = 10
DEFAULT_DECEL = 4.5
DEFAULT_EMERGENCY_DECEL = 9.0

# File Paths
config_directory = Path().joinpath('nets', '2lane_unprotected_right')
config_files = {
    'netccfg_edges': config_directory.joinpath('netconfig', 'edges.edg.xml'),
    'netccfg': config_directory.joinpath('2lane_unprotected_right.netccfg'),
    'duarcfg': config_directory.joinpath('2lane_unprotected_right.duarcfg'),
    'net.xml': config_directory.joinpath('2lane_unprotected_right.net.xml'),
    'rou.xml': config_directory.joinpath('2lane_unprotected_right.rou.xml'),
    'routes.rou.xml': config_directory.joinpath('routes.rou.xml'),
    'config.rou.xml': config_directory.joinpath('config.rou.xml'),
    'experimental.rou.xml': config_directory.joinpath('experimental.rou.xml'),
}
findAllRoutes = Path(sumo_home).joinpath('tools', 'findAllRoutes.py')
vehicle2flow = Path(sumo_home).joinpath('tools', 'route', 'vehicle2flow.py')


# Update Friction Coefficients in Edge Configuration
def update_friction_coefficients(file_path, friction):
    edges_xml_tree = ET.parse(file_path)
    edges_xml_root = edges_xml_tree.getroot()
    for param in edges_xml_root.findall(".//lane/param[@key='frictionCoefficient']"):
        param.set('value', str(friction))
    edges_xml_tree.write(file_path)


update_friction_coefficients(config_files['netccfg_edges'], FRICTION)

# Execute SUMO Tools
! netconvert --configuration-file {config_files['netccfg']}
! python {findAllRoutes} -n {config_files['net.xml']} -o {config_files['routes.rou.xml']} -s southJunction,westJunction -t junctionEast,junctionNorth
! duarouter --configuration-file {config_files['duarcfg']}
! python {vehicle2flow} {config_files['config.rou.xml']} -o {config_files['rou.xml']} -e {DURATION} -r {REPEAT_PERIOD}

# Update Vehicle Configuration for Friction Adjusted Braking Distance
def update_vehicle_type_parameters(file_path, speed, default_decel, default_emergency_decel, friction):
    tree = ET.parse(file_path)
    root = tree.getroot()
    for vType in root.findall('vType'):
        vClass = vType.attrib.get('vClass')
        if vClass and vClass != 'passenger':
            raise NotImplementedError("Check for non-passenger vehicle classes not implemented")
        vType.attrib.update({
            'maxSpeed': str(speed),
            'decel': str(default_decel * friction),
            'emergencyDecel': str(default_emergency_decel * friction),
        })
    tree.write(file_path, xml_declaration=True, encoding='UTF-8')


update_vehicle_type_parameters(config_files['rou.xml'], SPEED, DEFAULT_DECEL, DEFAULT_EMERGENCY_DECEL, FRICTION)


# Update Vehicle Flows for Forcing Unprotected Right Action
def update_flows(file_path, insert_probability):
    tree = ET.parse(file_path)
    root = tree.getroot()
    for flow in root.findall('flow'):
        match flow.attrib.get('id'):
            case 'southEast':
                flow.set('period', f"exp({insert_probability})")
            case 'southNorth':
                flow.set('period', f"exp({insert_probability})")
            case 'westEastTop':
                flow.set('period', f"exp({2 * insert_probability})")
            case flow_id if flow_id and 'westEastBottom' in flow_id:
                flow.set('end', str(float(flow.get('begin')) + 600))
                if float(flow.get('begin')) % 1200 == 0:
                    flow.set('period', f"exp({1 * insert_probability})")
                else:
                    flow.set('period', f"exp({0.0001 * insert_probability})")
    tree.write(file_path, xml_declaration=True, encoding='UTF-8')


update_flows(config_files['rou.xml'], INSERT_PROBABILITY)

Success.
Success.up to time step: 3200.00


In [2]:
from env.SumoEnvironmentGenerator import SumoEnvironmentGenerator
from pathlib import Path

net_name = '2lane_unprotected_right'

environments = SumoEnvironmentGenerator(
    net_file=str(Path().joinpath('nets', net_name, f'{net_name}.net.xml')),
    route_file=str(Path().joinpath('nets', net_name, f'{net_name}.rou.xml')),
    sumocfg_file=str(Path().joinpath('nets', net_name, f'{net_name}.sumocfg')),
    duration=3600,
    learning_data_csv_name=str(Path().joinpath('env', 'training_data', 'output.csv')),
)

### Training and saving the Model

In [None]:
from stable_baselines3.a2c import A2C

%load_ext tensorboard
env = environments.get_training_env()
model = A2C(
    env=env,
    policy='MlpPolicy',
    n_steps=100,
    # learning_rate=0.001,
    # learning_starts=0,
    # train_freq=1,
    # target_update_interval=500,
    # exploration_fraction=0.05,
    # exploration_final_eps=0.01,
    verbose=1,
    tensorboard_log='dqn_sumo_tensorboard'
)
model.learn(100_000, tb_log_name='a2c_100step_2lane_1delta_minute_1traffic_50speed_new')
model.save(Path().joinpath('env', 'training_data_2lane', 'a2c_100step_2lane_1delta_minute_1traffic_50speed_new'))

Giving the model a test run in an evaluation environment

In [3]:
#63 (ns green), 86 (incoming collision), 134 (we green), 0739 unprotected good

from stable_baselines3.a2c import A2C
from pathlib import Path

env = environments.get_demonstration_env()
# `load` is a classmethod, so there is no need to construct a throwaway model first
model = A2C.load(
    Path().joinpath('env', 'agents_paper', 'scratch_s80_f0.5.zip'), env=env)

rewards = []
actions = []
obs, info = env.reset()

gui = env.sumo.gui
gui.setBoundary(gui.DEFAULT_VIEW, 480.0, 480.0, 520.0, 520.0)
# gui.addView('3D', in3D=True)
screenshots_path_jpg = Path().joinpath('screenshots', 'tmp', 'jpg')
screenshots_path_jpg.mkdir(parents=True, exist_ok=True)
screenshots_path_svg = Path().joinpath('screenshots', 'tmp', 'svg')
screenshots_path_svg.mkdir(parents=True, exist_ok=True)

#vehicletype = env.sumo.vehicletype
#vehicletype.setMaxSpeed('carCustom', SPEED)
#vehicletype.setDecel('carCustom', vehicletype.getDecel('carCustom') * FRICTION)
#vehicletype.setEmergencyDecel('carCustom',
#                              vehicletype.getEmergencyDecel('carCustom') * FRICTION)
#for traffic_signal in env.traffic_signals.values():
#    for lane in traffic_signal.lanes:
#        traffic_signal.sumo.lane.setParameter(lane, 'frictionCoefficient', FRICTION)

done = False
while not done:
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    rewards.append(reward)
    actions.append(action)
    done = terminated or truncated
    current_step = int(env.sumo.simulation.getTime())
    screenshot_path_jpg = screenshots_path_jpg.joinpath(str(current_step).zfill(4) + '.jpg')
    screenshot_path_svg = screenshots_path_svg.joinpath(str(current_step).zfill(4) + '.svg')
    gui.screenshot(gui.DEFAULT_VIEW, str(screenshot_path_jpg))
    gui.screenshot(gui.DEFAULT_VIEW, str(screenshot_path_svg))
env.close()

print(rewards)
print(" ")
print(actions)

 Retrying in 1 seconds
Step #0.00 (0ms ?*RT. ?UPS, TraCI: 3ms, vehicles TOT 0 ACT 0 BUF 0)                      
 Retrying in 1 seconds
Could not connect to TraCI server at localhost:61721 [Errno 61] Connection refused
 Retrying in 1 seconds
Could not connect to TraCI server at localhost:61721 [Errno 61] Connection refused
 Retrying in 1 seconds
Could not connect to TraCI server at localhost:61721 [Errno 61] Connection refused
 Retrying in 1 seconds
Could not connect to TraCI server at localhost:61721 [Errno 61] Connection refused
 Retrying in 1 seconds
Could not connect to TraCI server at localhost:61721 [Errno 61] Connection refused
 Retrying in 1 seconds
[-0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0

### Produce Traces
Run the simulation repeatedly to produce traces (data) for causal discovery.

#### Data Selection
TODO: Select which columns we want to do Causal Discovery on
#### Data Summary
TODO: Incorporate old data summary script

**Desired Output: One CSV File containing all interesting data**


In [None]:
import itertools
from env.SumoTraceGenerator import SumoTraceGenerator
from pathlib import Path
from tqdm.notebook import tqdm

speeds = [30, 50, 100]
frictions = [1.0, 0.8, 0.5]

experiments = list(itertools.product(speeds, frictions))

for speed, friction in tqdm(experiments):
    simulation_output_path = Path().joinpath('data', f'a2c_{int(speed)}_f{friction}')
    simulation_output_path.mkdir(parents=True, exist_ok=True)

    trace_generator = SumoTraceGenerator()
    trace_generator.generate_traces(
        env_generator=environments,
        path=simulation_output_path,
        size=100,
        speed_loc=float(speed) / 3.6,
        friction_log=friction,
        friction_scale=0.1
    )

#### Data Summary

In [3]:
from statistics import mean
import glob
from pathlib import Path
import xml.etree.ElementTree as ElementTree
import pandas as pd

data_folder = Path().joinpath('traces_paper')
experiments = glob.glob(str(data_folder.joinpath('*')))

for experiment in experiments:
    experiment_path = Path(experiment)
    statistics_files = glob.glob(str(experiment_path.joinpath('*_statistics.xml')))
    # Use Path.name instead of splitting on '/' so this also works on Windows
    ids = [Path(path).name.split('_')[0] for path in statistics_files]

    data = []
    for id in ids:
        statistics_file = experiment_path.joinpath(id + '_statistics.xml')
        collisions_file = experiment_path.joinpath(id + '_collisions.xml')
        ssm_file = experiment_path.joinpath(id + '_ssm.xml')
        metadata_file = experiment_path.joinpath(id + '_metadata.xml')
        tripinfo_file = experiment_path.joinpath(id + '_tripinfo.xml')
        statistics_xml = ElementTree.parse(statistics_file).getroot()
        collisions_xml = ElementTree.parse(collisions_file).getroot()
        ssm_xml = ElementTree.parse(ssm_file).getroot()
        metadata_xml = ElementTree.parse(metadata_file).getroot()
        tripinfo_xml = ElementTree.parse(tripinfo_file).getroot()

        row = {
            'experiment': experiment,
            'index': int(id),
            'desiredSpeed': float(metadata_xml.find('.//desiredSpeed').text),
            'friction': float(metadata_xml.find('.//friction').text)
        }

        for key, value in {**statistics_xml.find('vehicleTripStatistics').attrib,
                           **statistics_xml.find('safety').attrib}.items():
            match key:
                case 'count' | 'emergencyStops' | 'emergencyBraking':
                    row[key] = int(value)
                case 'collisions':
                    # Replace SUMO's total with the collisions that involve the
                    # 'southEast' flow, split by its role in the collision
                    row['rearEndCollisions'] = sum(
                        'southEast' in child.attrib.get('victim') for child in collisions_xml)
                    row['lateralCollisions'] = sum(
                        'southEast' in child.attrib.get('collider') for child in collisions_xml)
                    row[key] = row['rearEndCollisions'] + row['lateralCollisions']
                case _:
                    row[key] = float(value)

        waiting_times = [float(tripinfo.attrib.get('waitingTime')) for tripinfo in tripinfo_xml.findall('tripinfo')]
        average_waiting_time = mean(waiting_times)
        row['waitingTime'] = average_waiting_time

        data.append(row)

    # Write the summary once per experiment, after all runs have been collected
    df = pd.DataFrame(data)
    df.to_csv(experiment_path.joinpath('.summary.csv'), index=False)

Some Visualization

In [None]:
import pandas as pd
from pathlib import Path
from matplotlib import pyplot as plt
import seaborn as sns

data = Path().joinpath('data_agent', 'data.csv')
paper_path = Path().joinpath('paper')
df = pd.read_csv(data)

sns.displot(data=df, x='friction', y='waitingTime', kind='kde')
plt.title('friction - waitingTime, all desiredSpeeds')
plt.savefig(paper_path.joinpath('friction-waitingTime.png'), bbox_inches='tight')

sns.displot(data=df, x='friction', y='collisions', kind='kde')
plt.title('friction - collisions, all desiredSpeeds')
plt.savefig(paper_path.joinpath('friction-collisions'), bbox_inches='tight')

sns.displot(data=df, x='desiredSpeed', y='waitingTime', kind='kde')
plt.title('desiredSpeed - waitingTime, all frictions')
plt.savefig(paper_path.joinpath('desiredSpeed-waitingTime'), bbox_inches='tight')

sns.displot(data=df, x='desiredSpeed', y='collisions', kind='kde')
plt.title('desiredSpeed - collisions, all frictions')
plt.savefig(paper_path.joinpath('desiredSpeed-collisions'), bbox_inches='tight')

sns.displot(data=df, x='friction', y='emergencyBraking', kind='kde')
plt.title('friction - emergencyBraking, all desiredSpeeds')
plt.savefig(paper_path.joinpath('friction-emergencyBraking'), bbox_inches='tight')

sns.displot(data=df, x='friction', y='speed', kind='kde')
plt.title('friction - speed, all desiredSpeeds')
plt.savefig(paper_path.joinpath('friction-speed'), bbox_inches='tight')

sns.displot(data=df, x='desiredSpeed', y='emergencyBraking', kind='kde')
plt.title('desiredSpeed - emergencyBraking, all frictions')
plt.savefig(paper_path.joinpath('desiredSpeed-emergencyBraking'), bbox_inches='tight')

sns.displot(data=df, x='desiredSpeed', y='speed', kind='kde')
plt.title('desiredSpeed - speed, all frictions')
plt.savefig(paper_path.joinpath('desiredSpeed-speed'), bbox_inches='tight')


Concatenate Summary CSVs

In [6]:
from pathlib import Path
import glob
import pandas as pd

data_folder = Path().joinpath('traces_paper')
summaries = glob.glob(str(data_folder.joinpath('*', '.summary.csv')))

dfs = []
for summary in summaries:
    agent = Path(summary).parent.name[0:-7]
    df = pd.read_csv(summary)
    df['agent'] = agent
    dfs.append(df)

final_df = pd.concat(dfs)[
    ['agent', 'desiredSpeed', 'friction', 'speed', 'waitingTime', 'emergencyBraking', 'collisions']]

final_df.to_csv(data_folder.joinpath('data.csv'), index=False)

### Causal Discovery
Discover causal graph

TODO: Decide which discovery algorithm to use

TODO: Figure out how to incorporate R code in Jupyter Notebook

**Desired Output: Causal Graph XML File**

In [None]:
# TODO: Find library to compute the adjustments.

from castle.common.priori_knowledge import PrioriKnowledge
import networkx as nx
import pandas as pd
import itertools
from castle.algorithms import PC, DirectLiNGAM
from pathlib import Path

data_csv = Path().joinpath('data_agent', 'data.csv')
df = pd.read_csv(data_csv)

columns = dict(enumerate(df.columns))
column_indexes = {value: int(key) for key, value in columns.items()}

priori_knowledge = PrioriKnowledge(len(columns))
independent_variables = ['desiredSpeed', 'friction']
outcome_variables = ['waitingTime', 'collisions']
forbidden_edges = [
    *({'source': column, 'target': 'desiredSpeed'} for column in column_indexes.keys() if column != 'desiredSpeed'),
    *({'source': column, 'target': 'friction'} for column in column_indexes.keys() if column != 'friction'),
    *({'source': source_node, 'target': target_node} for source_node, target_node in
      itertools.permutations(outcome_variables, 2))
]
priori_knowledge.add_forbidden_edges(
    [(column_indexes[edge['source']], column_indexes[edge['target']]) for edge in forbidden_edges])

pc = PC(variant='stable', priori_knowledge=priori_knowledge)
pc.learn(df.values.tolist())

G_PC = nx.DiGraph(pc.causal_matrix)
H_PC = nx.relabel_nodes(G_PC, dict(enumerate(df.columns)))

color_map = [
    'green' if node in independent_variables else
    'red' if H_PC.out_degree(node) == 0 else
    'yellow' for node in H_PC.nodes
]

nx.draw(G=H_PC, node_color=color_map, node_size=1200, arrowsize=30, with_labels=True,
        pos=nx.circular_layout(H_PC))


Compute directionality of arrows

In [None]:
from matplotlib import pyplot as plt
from scipy.stats import spearmanr
from sklearn.linear_model import LinearRegression
import seaborn as sns

direction_adjusted_edge_confidence = {}

for edge in H_PC.edges():

    X = edge[0]
    Y = edge[1]
    assert H_PC.has_edge(X, Y)

    X_data = df[X].values.reshape(-1, 1)
    Y_data = df[Y].values.reshape(-1, 1)

    # X = f(Y)
    model_X_Y = LinearRegression()
    model_X_Y.fit(Y_data, X_data)
    X_pred = model_X_Y.predict(Y_data)
    residuals_X = X_data - X_pred

    # Y = g(X)
    model_Y_X = LinearRegression()
    model_Y_X.fit(X_data, Y_data)
    Y_pred = model_Y_X.predict(X_data)
    residuals_Y = Y_data - Y_pred

    #Spearman's rank correlation
    corr_X_resid, p_X_resid = spearmanr(X_data.ravel(), residuals_Y.ravel())
    corr_Y_resid, p_Y_resid = spearmanr(Y_data.ravel(), residuals_X.ravel())
    print(f'corr_{X}_resid:', corr_X_resid)
    print(f'corr_{Y}_resid:', corr_Y_resid)

    confidence_abs = abs(abs(corr_Y_resid) - abs(corr_X_resid))
    if abs(corr_X_resid) < abs(corr_Y_resid):
        print(f"{X} -> {Y}")
        direction_adjusted_edge_confidence[(X, Y)] = confidence_abs
    else:
        print(f"{Y} -> {X}")
        direction_adjusted_edge_confidence[(Y, X)] = confidence_abs

    # fig, axs = plt.subplots(1, 2, figsize=(12, 6), constrained_layout=True)

    # Residuals of waitingTime on speed
    #sns.scatterplot(x=X_data.ravel(), y=residuals_Y.ravel(), ax=axs[0])
    #axs[0].set_title(f"Residuals of {Y} ~ {X}")
    #axs[0].set_xlabel(f"{X}")
    #axs[0].set_ylabel(f"Residuals ({Y})")

    # Residuals of speed on waitingTime
    #sns.scatterplot(x=Y_data.ravel(), y=residuals_X.ravel(), ax=axs[1])
    #axs[1].set_title(f"Residuals of {X} ~ {Y}")
    #axs[1].set_xlabel(f"{Y}")
    #axs[1].set_ylabel(f"Residuals ({X})")

    print("==============================")

direction_adjusted_edges = []
for edge, confidence in direction_adjusted_edge_confidence.items():
    if edge not in H_PC.edges():
        if confidence < 0.1:
            edge = edge[::-1]
    direction_adjusted_edges.append(edge)

I_PC = H_PC.copy()
I_PC.clear_edges()

I_PC.add_edges_from(direction_adjusted_edges)
print(direction_adjusted_edge_confidence)

color_map = [
    'green' if node in independent_variables else
    'red' if I_PC.out_degree(node) == 0 else
    'yellow' for node in I_PC.nodes
]

nx.draw(G=I_PC, node_color=color_map, node_size=1200, arrowsize=30, with_labels=True,
        pos=nx.circular_layout(H_PC))

Compute Adjustments
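
Until a dedicated library is chosen, one simple option is to adjust for the parents of the treatment variable. In a DAG whose variables are all observed, the parents of the treatment block every backdoor path to the outcome, as long as the outcome is not itself a parent of the treatment. The sketch below assumes the edges form a DAG (it does not check this); the function name and the toy graph are hypothetical.

```python
from collections.abc import Iterable


def parent_adjustment_set(
    edges: Iterable[tuple[str, str]], treatment: str, outcome: str
) -> set[str]:
    """Return the parents of `treatment` as a backdoor adjustment set.

    Assumes `edges` describes a DAG over fully observed variables.
    """
    parents = {source for source, target in edges if target == treatment}
    if outcome in parents:
        raise ValueError("outcome must not be a parent of the treatment")
    return parents


# Toy graph, assumed for illustration (not the discovered graph)
toy_edges = [
    ("friction", "speed"),
    ("friction", "collisions"),
    ("speed", "collisions"),
]
print(parent_adjustment_set(toy_edges, "speed", "collisions"))  # {'friction'}
```

Because the function only needs `(source, target)` pairs, the edge view of a networkx graph such as `I_PC.edges()` can be passed in directly.
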

### Fit MLMs
Fit MLMs based on Causal Discovery Graph

TODO: Parse Graph XML into MLM parameters / formulae

**Desired Output: MLM**

In [None]:
from sklearn.metrics import mean_squared_error
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression

linear_models = {}

X_train, X_test, y_train, y_test = train_test_split(df, df, test_size=0.2)

for node in H_PC.nodes:
    if H_PC.in_degree(node) > 0:
        lm_independent_variables = list(H_PC.predecessors(node))
        lm_target_variable = node

        X_node = X_train[list(set(lm_independent_variables + independent_variables))]
        y_node = y_train[lm_target_variable]

        model = LinearRegression()
        model.fit(X_node[lm_independent_variables], y_node)

        model_simple = LinearRegression()
        model_simple.fit(X_node[independent_variables], y_node)

        y_pred = model.predict(X_test[lm_independent_variables])
        mse = mean_squared_error(y_test[lm_target_variable], y_pred)
        r2 = r2_score(y_test[lm_target_variable], y_pred)

        y_pred_simple = model_simple.predict(X_test[independent_variables])
        mse_simple = mean_squared_error(y_test[lm_target_variable], y_pred_simple)
        r2_simple = r2_score(y_test[lm_target_variable], y_pred_simple)

        linear_models[node] = {'causal_lm': model, 'simple_lm': model_simple}

In [None]:
causal_emergency_braking_lm = linear_models['emergencyBraking']['causal_lm']
# causal_time_loss_lm = linear_models['timeLoss']['causal_lm']
causal_rear_end_lm = linear_models['rearEndCollisions']['causal_lm']
causal_x_test = X_test.copy()

causal_emergency_braking = causal_emergency_braking_lm.predict(X_test[causal_emergency_braking_lm.feature_names_in_])
causal_x_test['emergencyBraking'] = causal_emergency_braking

# causal_time_loss = causal_time_loss_lm.predict(X_test[causal_time_loss_lm.feature_names_in_])
# causal_x_test['timeLoss'] = causal_time_loss

# causal_rear_end = causal_rear_end_lm.predict(X_test[causal_rear_end_lm.feature_names_in_])
# causal_x_test['rearEndCollisions'] = causal_rear_end

for node, model in linear_models.items():
    # Causal LM 
    model_causal = model['causal_lm']
    lm_independent_variables = model_causal.feature_names_in_
    lm_target_variable = node

    causal_y_pred = model_causal.predict(causal_x_test[lm_independent_variables])
    causal_mse = mean_squared_error(y_test[lm_target_variable], causal_y_pred)
    causal_r2 = r2_score(y_test[lm_target_variable], causal_y_pred)

    # Simple LM
    model_simple = model['simple_lm']
    y_pred_simple = model_simple.predict(X_test[independent_variables])
    mse_simple = mean_squared_error(y_test[lm_target_variable], y_pred_simple)
    r2_simple = r2_score(y_test[lm_target_variable], y_pred_simple)

    print("causal", node, model_causal.feature_names_in_, model_causal.coef_, "With: mse ", causal_mse, ", r2 ",
          causal_r2)
    print("simple", node, model_simple.feature_names_in_, model_simple.coef_, "With: mse ", mse_simple, ", r2 ",
          r2_simple)
    print(" ")

print("Done")

### Produce Interventions

#### Covariate Shift Distribution
* Create a distribution for the covariate (friction) shift
* Sample from distribution
    * Fulfill Assumption: sparse sample data is representative for covariate shift ground truth
* Produce Traces for sparse input data

#### Crank MLM the other way
* Calculate Intervention Distribution by inputting sparse data into MLM

**Desired Output: Intervention Distribution**
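
As a rough illustration of this step, the sketch below fits a one-variable least-squares model on synthetic data and then feeds sparse shifted friction values through it to obtain a predicted outcome distribution. The plain least-squares fit is a stand-in for the fitted model, and all data, coefficients, and variable names are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic training data (assumed): collisions fall linearly as friction rises
friction_train = rng.uniform(0.5, 1.0, size=200)
collisions_train = 10.0 - 8.0 * friction_train + rng.normal(0.0, 0.1, size=200)

# Fit collisions ~ intercept + slope * friction by ordinary least squares
design = np.column_stack([np.ones_like(friction_train), friction_train])
(intercept, slope), *_ = np.linalg.lstsq(design, collisions_train, rcond=None)

# Sparse samples from the shifted friction distribution (assumed values)
friction_shifted = np.array([0.45, 0.5, 0.55, 0.6])

# Predicted outcomes under the shift; their empirical distribution is the
# estimate of the intervention distribution
collisions_shifted = intercept + slope * friction_shifted
print(collisions_shifted)
```

Note that the shifted friction values partly lie outside the training range, so the prediction relies on the linear relationship extrapolating, which would need to be checked against real traces.
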

In [None]:
# TODO

### Generate Posterior Distributions
TODO: Generate Posterior Distributions without intervention

TODO: Generate Posterior Distributions with intervention

**Desired Output: Two XML Files**

In [None]:
# TODO

### Query
Compare the distributions and decide which part of the model to retrain.

TODO: Classify the data / model in parts

In [None]:
# TODO

### Evaluation

#### Agent
Compare the resulting agent (training partially continued, depending on the Query) to:
* Old agent (lower performance bound)
* Completely newly trained agent (upper performance bound)
* (New agent trained completely on new data, without Query)

#### Intervention
Function: Number of Covariate Shift Samples --> Wasserstein distance: Intervention vs. ground truth (distribution)

#### MLM
* Wasserstein Distance: Effect of Intervention vs. ground truth effect
* Maybe also as a function of the number of retrain samples
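
The Wasserstein comparisons above could be sketched as follows. For two one-dimensional samples of equal size, the 1-Wasserstein distance is the mean absolute difference between the sorted samples, which is what the hypothetical helper below computes. The distributions and sample sizes are synthetic assumptions; `scipy.stats.wasserstein_distance` covers samples of unequal size.

```python
import numpy as np


def wasserstein_1d(a, b) -> float:
    """1-Wasserstein distance between two equally sized 1-D samples."""
    a = np.sort(np.asarray(a, dtype=float))
    b = np.sort(np.asarray(b, dtype=float))
    if a.ndim != 1 or a.shape != b.shape:
        raise ValueError("expects two 1-D samples of equal size")
    return float(np.mean(np.abs(a - b)))


rng = np.random.default_rng(0)
ground_truth = rng.normal(0.5, 0.1, size=1000)  # assumed ground-truth sample

# Distance as a function of the number of covariate shift samples
for n in (10, 100, 1000):
    estimate = rng.normal(0.5, 0.1, size=n)
    # Subsample the ground truth so both samples have the same size
    reference = rng.choice(ground_truth, size=n, replace=False)
    print(n, wasserstein_1d(estimate, reference))
```
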



In [None]:
# TODO

### Ideas
* Maybe no change in friction but rather only in the requirements
* Sophisticated Query: Causal Graph of Transfer Learning --> Generate Posterior for different transfer learning options --> Rank and choose best.
* Collision / Penalty Factor for managing Safety/Performance tradeoff