In [24]:
import torch
import torch.nn as nn
import torchmetrics
import torch.optim as optim
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions import Uniform
import matplotlib.pyplot as plt

In [2]:
D = 20
mus = torch.ones(D)
sigmas = torch.eye(D)

In [3]:
mvn = MultivariateNormal(mus, sigmas)
unif = Uniform(0, 1)

In [9]:
class Generator(nn.Module):
    def __init__(self, n_layer=2):        
        super(Generator, self).__init__()
        self.fc = nn.Sequential(*[
            nn.Linear(20, 20), nn.Tanh()]*n_layer)        
    def forward(self, X):
        return self.fc(X)

class Discriminator(nn.Module):
    def __init__(self, n_layer=2):        
        super(Discriminator, self).__init__()
        self.fc = nn.Sequential(*[
            nn.Linear(20, 20), nn.Tanh()]*n_layer)  
        self.cls = nn.Sequential(nn.Linear(20, 1), nn.Sigmoid())
        
    def forward(self, X):
        z = self.fc(X)
        out = self.cls(z)
        return out

In [15]:
accuracy = torchmetrics.Accuracy()

In [51]:
from torch.utils.data import TensorDataset, DataLoader
G = Generator().cuda()
D = Discriminator().cuda()
G_loss_mean = torchmetrics.MeanMetric()
D_loss_mean = torchmetrics.MeanMetric()
reals = mvn.sample((1600,))
real_loader = DataLoader(TensorDataset(reals), batch_size=2)
optG = optim.AdamW(G.parameters(), lr=0.001)
optD = optim.AdamW(D.parameters(), lr=0.001)

In [54]:
train_gen = True
for epoch_i in range(10):
    for (real_batch,) in real_loader:
        real_batch = real_batch.cuda()
        optD.zero_grad()
        z = unif.sample((32, 20)).cuda()
        fakes = G(z)
        loss_D = -torch.mean(D(real_batch)) + torch.mean(D(fakes.detach()))
        loss_D.backward()
        optD.step()
        D_loss_mean.update(loss_D.item())
        for p in D.parameters():
            p.data.clamp_(-0.1, 0.1)

        ## train generator
        if train_gen:    
            optG.zero_grad()
            loss_G = -torch.mean(D(fakes))
            loss_G.backward()
            optG.step()
            G_loss_mean.update(loss_G.item())     
    print(f"loss_D: {D_loss_mean.compute().item()}, loss_G: {G_loss_mean.compute().item()}")    


loss_D: -0.0039813625626266, loss_G: -0.48899349570274353
loss_D: -0.005479327403008938, loss_G: -0.49022579193115234
loss_D: -0.014151657931506634, loss_G: -0.4780244827270508
loss_D: -0.01222685631364584, loss_G: -0.48399192094802856
loss_D: -0.011000297032296658, loss_G: -0.478678435087204
loss_D: -0.009856012649834156, loss_G: -0.48085400462150574
loss_D: -0.009009427390992641, loss_G: -0.4768573045730591
loss_D: -0.008427336812019348, loss_G: -0.47773244976997375
loss_D: -0.008702069520950317, loss_G: -0.47769883275032043
loss_D: -0.00877364818006754, loss_G: -0.47497087717056274
