In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import seaborn as sns


class MatrixEnvironment:
    """Single-step environment whose state is an ``m x n`` action matrix.

    ``step`` stores the given action matrix as the new state and returns the
    scalar reward produced by ``calculate_reward``.
    """

    def __init__(self, m, n, p):
        """
        Args:
            m: number of rows of the state/action matrix.
            n: number of columns of the state/action matrix.
            p: per-row factor used by ``calculate_reward``: each row holding a 1
                in a column multiplies that column's weight by ``p``.
        """
        self.m = m
        self.n = n
        self.p = p
        # All-zero initial state; requires_grad so a reward computed from it
        # can be backpropagated.
        self.state = torch.zeros(self.m, self.n, dtype=torch.float32, requires_grad=True)

    def get_state(self):
        """Return the current state (the last action passed to ``step``)."""
        return self.state

    def step(self, action_matrix):
        """Make ``action_matrix`` the new state and return its reward.

        The episode is done after this single step; only the reward is returned.
        """
        self.state = action_matrix
        return self.calculate_reward(action_matrix)

    def calculate_reward(self, X, lost_penalty=40, a=1, l=0, o=100, s=1):
        """Scalar reward for action matrix ``X``.

        Column ``j`` gets weight ``p ** k_j``, where ``k_j`` is the number of rows
        with ``X[i, j] == 1``. The reward is ``a * sum(X @ weights)``.

        Args:
            X: 2-D tensor of shape (rows, cols); entries are tested with ``== 1``.
            lost_penalty, l, o, s: accepted for compatibility, currently unused.
            a: scale factor applied to the summed reward.

        Returns:
            0-d tensor. It is differentiable w.r.t. ``X`` only through the final
            matmul; the column weights come from ``X == 1``, which carries no
            gradient.
        """
        m_nr, t_nr = X.shape
        dtype = X.dtype if X.is_floating_point() else torch.float32
        # One factor per ROW: ``X[:, j] == 1`` selects rows, so this vector needs
        # m_nr entries. Sizing it by t_nr broke every non-square X.
        m = torch.full((m_nr,), float(self.p), dtype=dtype, device=X.device)
        t = torch.ones(t_nr, dtype=dtype, device=X.device)
        # Per column: product of m over the rows holding a 1. An empty product
        # is 1, same as torch.prod of an empty selection.
        mm = torch.where(X == 1, m.unsqueeze(1), torch.ones_like(X, dtype=dtype)).prod(dim=0)
        A = torch.matmul(X, mm * t)
        # A = torch.where(A > 1, torch.ones_like(A), A)

        # TODO latency term (weighted by ``l``, using ``lost_penalty``) is not
        # implemented; earlier sketch:
        # L = [torch.where(X[:, torch.where((X[msg, :] * (mm * t)) > 0)[0][0]] == 1)[0][-1].item() - msg if A[msg] > 0 else lost_penalty for msg in range(m_nr)]
        # L = torch.tensor(L, dtype=torch.float32)

        return a * torch.sum(A)  # + l * torch.sum(L)




In [2]:
class FullyConnectedPolicyNetwork(nn.Module):
    """MLP mapping one ``m x n`` matrix to an ``m x n`` matrix of sigmoid outputs.

    Widths: m*n -> 128 -> 256 -> 128 -> m*n, with ReLU after each hidden layer
    and a sigmoid on the output.
    """

    def __init__(self, m, n):
        super().__init__()
        flat_size = m * n
        # Layers are created in fc1..fc4 order so parameter init and
        # state_dict keys match across versions.
        self.fc1 = nn.Linear(flat_size, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, flat_size)
        self.m, self.n = m, n

    def forward(self, x):
        """Flatten the entire input (no batch dim), run the MLP, reshape to (m, n)."""
        hidden = x.reshape(-1)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = torch.relu(layer(hidden))
        probs = torch.sigmoid(self.fc4(hidden))
        return probs.reshape(self.m, self.n)
    


class DeepPolicyNetworkCNN(nn.Module):
    """Convolutional policy: one ``m x n`` matrix in, ``m x n`` sigmoid outputs out.

    Three 3x3 convolutions (1 -> 16 -> 32 -> 64 channels, stride 1, padding 1,
    so the spatial size stays m x n), then Linear 64*m*n -> 128 -> m*n.
    """

    def __init__(self, m, n):
        super().__init__()

        def same_size_conv(c_in, c_out):
            # 3x3 kernel with padding 1 and stride 1 keeps height/width unchanged.
            return nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, stride=1, padding=1)

        self.conv1 = same_size_conv(1, 16)
        self.conv2 = same_size_conv(16, 32)
        self.conv3 = same_size_conv(32, 64)
        self.fc1 = nn.Linear(64 * m * n, 128)
        self.fc2 = nn.Linear(128, m * n)
        self.m = m
        self.n = n

    def forward(self, x):
        """Run a single (m, n) matrix through the network and return an (m, n) tensor."""
        features = x[None, None]  # prepend batch and channel dims -> (1, 1, m, n)
        for conv in (self.conv1, self.conv2, self.conv3):
            features = torch.relu(conv(features))
        flat = features.reshape(-1, 64 * self.m * self.n)
        hidden = torch.relu(self.fc1(flat))
        return torch.sigmoid(self.fc2(hidden)).reshape(self.m, self.n)


# def gumbel_softmax(logits, tau=1, hard=False):
#     gumbels = -torch.empty_like(logits).exponential_().log()  # ~Gumbel(0,1)
#     gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
#     y_soft = gumbels.softmax(dim=-1)

#     if hard:
#         # Straight through.
#         index = y_soft.max(dim=-1, keepdim=True)[1]
#         y_hard = torch.zeros_like(logits).scatter_(-1, index, 1.0)
#         ret = y_hard - y_soft.detach() + y_soft
#     else:
#         # Reparameterization trick.
#         ret = y_soft
#     return ret

In [3]:
# Example reward function usage / policy training loop.
SEED = 0
torch.manual_seed(SEED)  # seed before building the network so init is reproducible
np.random.seed(SEED)
random.seed(SEED)

m, n = 4, 4
p = 0.9
N_EPISODES = 100
LR = 1e-3

# Hand-built example action matrix, handy for manual checks such as env.step(X).
X = torch.tensor([[1, 0, 0, 0],
                  [1, 1, 0, 0],
                  [0, 0, 1, 0],
                  [0, 0, 0, 1]], requires_grad=True, dtype=torch.float32)

env = MatrixEnvironment(m, n, p)
policy_net = FullyConnectedPolicyNetwork(m, n)
optimizer = optim.Adam(policy_net.parameters(), lr=LR)

rewards = []
for episode in range(N_EPISODES):
    optimizer.zero_grad()

    # Detach the incoming state so each episode has its own graph. Otherwise
    # every action depends on all previous ones, and backward() re-walks (and
    # keeps alive) the entire history.
    state = env.get_state().detach()
    probs = policy_net(state)

    # torch.bernoulli passes no gradient, so use a straight-through estimator:
    # the forward value is the hard 0/1 sample (probs - probs.detach() is exactly
    # 0, so the `X == 1` test in the reward still works), and the backward
    # gradient flows into probs.
    hard = torch.bernoulli(probs.detach())
    X = hard + (probs - probs.detach())
    X.retain_grad()  # keep dReward/dX inspectable via X.grad

    r = env.step(X)
    loss = -r  # maximise the reward by minimising its negative
    loss.backward()
    optimizer.step()

    rewards.append(r.item())
    

