In [None]:

import numpy as np

import torch
import torch.nn as nn
from torch.functional import F

from matplotlib import pyplot as plt
from src.environments.simple_microgrid import SimpleMicrogrid
from src.utils.tools import set_all_seeds, load_config, plot_rollout
torch.autograd.set_detect_anomaly(True)

In [None]:
set_all_seeds(0)
config = load_config("d_a2c_fed")

# Select the compute device. The MPS backend is tried first; CUDA is used only
# if it is available AND enabled in the agent config; otherwise fall back to CPU.
# `device` is bound in every branch so later cells can rely on it existing.
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
    # Fix: the original built the MPS device but never assigned it, leaving
    # `device` undefined on MPS machines.
    device = torch.device("mps")
    print("MPS enabled")
elif torch.cuda.is_available() and config['agent']['enable_gpu']:
    print("MPS not available, using CUDA")
    device = torch.device("cuda:0")
else:
    print("MPS and CUDA not available, using CPU")
    device = torch.device("cpu")

env = SimpleMicrogrid(config=config['env'])


In [None]:
# Agent definitions


class Actor(nn.Module):
    """Policy network mapping (observation, attribute) features to action probabilities.

    Architecture: one hidden ReLU layer followed by a softmax over actions.

    Args:
        obs_dim: size of the last (feature) dimension of ``obs``.
        attr_dim: size of the last (feature) dimension of ``attr``.
        act_dim: number of discrete actions produced.
        hidden_dim: width of the hidden layer.
    """

    def __init__(self, obs_dim, attr_dim, act_dim, hidden_dim=64) -> None:
        super().__init__()

        self.input = nn.Linear(obs_dim + attr_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, act_dim)

    def forward(self, obs, attr):
        """Return action probabilities with shape ``(*leading, act_dim)``.

        ``obs`` and ``attr`` must share all leading dimensions. They are
        concatenated along the last (feature) axis with ``attr`` first; keep
        this order so previously trained weights remain valid.
        """
        # Operate on the last axis so any leading batch shape works; the
        # original hard-coded dim=2 and therefore only accepted 3-D inputs.
        features = torch.cat([attr, obs], dim=-1)
        hidden = F.relu(self.input(features))

        return F.softmax(self.output(hidden), dim=-1)

class Critic(nn.Module):
    """Value network mapping (observation, attribute) features to a scalar estimate.

    Architecture: one hidden ReLU layer followed by a linear layer with a
    single output unit (no output activation).

    Args:
        obs_dim: size of the last (feature) dimension of ``obs``.
        attr_dim: size of the last (feature) dimension of ``attr``.
        hidden_dim: width of the hidden layer.
    """

    def __init__(self, obs_dim, attr_dim, hidden_dim=64) -> None:
        super().__init__()

        self.input = nn.Linear(obs_dim + attr_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, 1)

    def forward(self, obs, attr):
        """Return value estimates with shape ``(*leading, 1)``.

        ``obs`` and ``attr`` must share all leading dimensions. They are
        concatenated along the last (feature) axis with ``attr`` first; keep
        this order so previously trained weights remain valid.
        """
        # Concatenate on the last axis; the original hard-coded dim=3, which
        # only accepted 4-D inputs and disagreed with Actor's dim=2.
        features = torch.cat([attr, obs], dim=-1)
        hidden = F.relu(self.input(features))

        return self.output(hidden)

In [None]:
# Number of local steps. Given the "d_a2c_fed" config, presumably the number of
# local updates each agent runs between federated aggregation rounds.
# TODO confirm against the training loop that consumes it.
local_steps = 100
