In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
from collections import namedtuple, deque

In [2]:
device = torch.device('cuda')

In [3]:
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

In [4]:
class ReplayMemory(object):
    
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)
    
    def push(self, *args):
        self.memory.append(Transition(*args))
        
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)
    
    def __len__(self):
        return len(self.memory)

In [5]:
batch_size = 3

In [6]:
class Generator(nn.Module):
    def __init__(self, n_states):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(n_states, 50),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Linear(50, n_states),
        )
        
    def forward(self, x):
        x = self.model(x)
        return x

In [7]:
class Discriminator(nn.Module):
    def __init__(self, n_states):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(n_states, 50),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Linear(50, 1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        x = self.model(x)
        return x

In [8]:
batch_size = 3
n_states = 10
lr = 0.001

In [9]:
rnd = np.random.choice(np.arange(10), batch_size)
rnd

array([9, 2, 1])

In [10]:
torch.from_numpy(rnd)

tensor([9, 2, 1])

In [11]:
F.one_hot(torch.from_numpy(rnd), num_classes=10)

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]])

In [12]:
D = Discriminator(n_states)
D_opt = torch.optim.Adam(D.parameters(), lr=lr)
G = Generator(n_states)
G_opt = torch.optim.Adam(G.parameters(), lr=lr)
criterion = nn.BCELoss()

In [13]:
x = F.one_hot(torch.from_numpy(rnd), num_classes=10)
x = x.float()

In [14]:
D(x)

tensor([[0.5473],
        [0.5301],
        [0.5383]], grad_fn=<SigmoidBackward0>)

In [15]:
labels = torch.ones(batch_size)

In [16]:
noise = torch.randn(batch_size, n_states)
G(noise)

tensor([[ 0.2442,  0.2480,  0.2726,  0.2822, -0.1477,  0.0340, -0.0071, -0.4850,
          0.1369,  0.5594],
        [-0.0065, -0.0866, -0.0385, -0.0149, -0.0026,  0.0224, -0.0984,  0.2024,
         -0.0064,  0.2346],
        [ 0.3048,  0.5615,  0.4681,  0.2669,  0.1271,  0.3877, -0.3517, -0.5092,
          0.4134,  0.6463]], grad_fn=<AddmmBackward0>)

In [17]:
D.zero_grad()
x = F.one_hot(torch.from_numpy(rnd), num_classes=10).float()
labels = torch.ones(batch_size)
out = D(x)
print(out, labels.unsqueeze(1))
D_loss = criterion(out, labels.unsqueeze(1))
D_loss.backward()

noise = torch.randn(batch_size, n_states)
fake = G(noise)
labels.fill_(0)

out = D(fake.detach())
D_loss = criterion(out, labels.unsqueeze(1))

D_loss.backward()
D_opt.step()

G.zero_grad()
labels.fill_(1)
out = D(fake)
G_loss = criterion(out, labels.unsqueeze(1))
G_loss.backward()
G_opt.step()

tensor([[0.5473],
        [0.5301],
        [0.5383]], grad_fn=<SigmoidBackward0>) tensor([[1.],
        [1.],
        [1.]])
