In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets,transforms
from torch.utils.data import DataLoader

In [2]:
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

In [3]:
batch_size=64
data = datasets.CIFAR10(root='./data/',download=False,transform=transforms.ToTensor())
data_loader = DataLoader(dataset=data,batch_size=batch_size,shuffle=True)

In [4]:
latent_size = 64
dfeatures = 64
gfeatures = 64
nchannels = 3
DNet = nn.Sequential(nn.Conv2d(nchannels,dfeatures,kernel_size=4,stride=2,padding=1),
                     nn.LeakyReLU(0.2),
                     nn.Conv2d(dfeatures,2*dfeatures,kernel_size=4,stride=2,padding=1,bias=False),
                     nn.BatchNorm2d(2*dfeatures),
                     nn.LeakyReLU(0.2),
                     nn.Conv2d(2*dfeatures,4*dfeatures,kernel_size=4,stride=2,padding=1,bias=False),
                     nn.BatchNorm2d(4*dfeatures),
                     nn.LeakyReLU(0.2),
                     nn.Conv2d(4*dfeatures,1,kernel_size=4))
GNet = nn.Sequential(nn.ConvTranspose2d(latent_size,4*gfeatures,kernel_size=4,bias=False),
                     nn.BatchNorm2d(4*gfeatures),
                     nn.ReLU(),
                     nn.ConvTranspose2d(4*gfeatures,2*gfeatures,kernel_size=4,stride=2,padding=1,bias=False),
                     nn.BatchNorm2d(2*gfeatures),
                     nn.ReLU(),
                     nn.ConvTranspose2d(2*gfeatures,gfeatures,kernel_size=4,stride=2,padding=1,bias=False),
                     nn.BatchNorm2d(gfeatures),
                     nn.ReLU(),
                     nn.ConvTranspose2d(gfeatures,nchannels,kernel_size=4,stride=2,padding=1))                     

In [5]:
criterion = nn.BCEWithLogitsLoss()
doptimizer = optim.Adam(DNet.parameters(),lr=0.001,betas=(0,0.99))
goptimizer = optim.Adam(GNet.parameters(),lr=0.001,betas=(0,0.99))
DNet.to(device)
GNet.to(device)
criterion.to(device)

BCEWithLogitsLoss()

In [6]:
def weights_init(m):
    if type(m) in [nn.Conv2d,nn.ConvTranspose2d]:
        nn.init.xavier_normal_(m.weight)
    elif type(m)==nn.BatchNorm2d:
        nn.init.constant_(m.bias,0)
        nn.init.normal_(m.weight,1.0,0.2)
DNet.apply(weights_init)
GNet.apply(weights_init)

Sequential(
  (0): ConvTranspose2d(64, 256, kernel_size=(4, 4), stride=(1, 1), bias=False)
  (1): ReLU()
  (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (4): ReLU()
  (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (6): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (7): ReLU()
  (8): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (9): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
)

In [8]:
num_epochs = 10
for epoch in range(num_epochs):
    for batch_idx, data in enumerate(data_loader):
        # DNet
        ## positive
        real_img,_=data
        batch_size = real_img.size(0)
        pos_label = torch.ones(batch_size)
        real_img, pos_label = real_img.to(device), pos_label.to(device)
        preds = DNet(real_img)
        outputs = preds.view(-1)
        dloss_real = criterion(outputs,pos_label)
        dmean_real = preds.sigmoid().mean()
        ## negative
        neg_label = torch.zeros(batch_size)
        noises = torch.randn(batch_size,latent_size,1,1)
        noises = noises.to(device)
        fake_img = GNet(noises)
        fake_img_d = fake_img.detach()
        fake_img_d, neg_label = fake_img_d.to(device), neg_label.to(device)
        preds = DNet(fake_img_d)
        outputs = preds.view(-1)
        dloss_fake = criterion(outputs,neg_label)
        dmean_fake = preds.sigmoid().mean()
        dloss = dloss_fake+dloss_real
        doptimizer.zero_grad()
        dloss.backward()
        doptimizer.step()
        # GNet
        pos_label = torch.ones(batch_size)
        fake_img, pos_label = fake_img.to(device), pos_label.to(device)
        preds = DNet(fake_img)
        outputs = preds.view(-1)
        gloss = criterion(outputs,pos_label)
        gmean = preds.sigmoid().mean()
        goptimizer.zero_grad()
        gloss.backward()
        goptimizer.step()
        if (batch_idx+1)%100==0:
            print(f'Epoch {epoch+1}: {(batch_idx+1)*batch_size}/{len(data_loader)*batch_size}\
                    DLoss={dloss.item()} GLoss={gloss.item()} TP={dmean_real} FP={dmean_fake/dmean_real}')

Epoch 1: 6400/50048                    DLoss=0.15704606473445892 GLoss=2.8967294692993164 TP=0.9985477924346924 FP=0.13908635079860687
Epoch 1: 12800/50048                    DLoss=0.018363824114203453 GLoss=4.7067389488220215 TP=0.9951010942459106 FP=0.01341303065419197
Epoch 1: 19200/50048                    DLoss=1.5726850032806396 GLoss=0.36824649572372437 TP=0.2679170072078705 FP=0.25432127714157104
Epoch 1: 25600/50048                    DLoss=0.2673996686935425 GLoss=2.382692813873291 TP=0.9601961374282837 FP=0.20907986164093018
Epoch 1: 32000/50048                    DLoss=1.908254623413086 GLoss=0.2461148202419281 TP=0.2451159507036209 FP=0.6441898345947266
Epoch 1: 38400/50048                    DLoss=0.39962679147720337 GLoss=3.0069031715393066 TP=0.8208999633789062 FP=0.16941089928150177
Epoch 1: 44800/50048                    DLoss=0.5236822366714478 GLoss=2.851461410522461 TP=0.9175747632980347 FP=0.3593634068965912
Epoch 2: 6400/50048                    DLoss=1.186704397

Epoch 9: 44800/50048                    DLoss=1.307892084121704 GLoss=0.6049456000328064 TP=0.37242749333381653 FP=0.49521875381469727
Epoch 10: 6400/50048                    DLoss=1.0001047849655151 GLoss=3.3033580780029297 TP=0.8246070146560669 FP=0.6121111512184143
Epoch 10: 12800/50048                    DLoss=1.2356964349746704 GLoss=0.8209010362625122 TP=0.4245642423629761 FP=0.5328812599182129
Epoch 10: 19200/50048                    DLoss=0.8028992414474487 GLoss=1.76851487159729 TP=0.6721650958061218 FP=0.43780067563056946
Epoch 10: 25600/50048                    DLoss=1.445138931274414 GLoss=0.9265451431274414 TP=0.40053874254226685 FP=0.7606292963027954
Epoch 10: 32000/50048                    DLoss=1.1730576753616333 GLoss=0.7719698548316956 TP=0.4203357696533203 FP=0.3666422665119171
Epoch 10: 38400/50048                    DLoss=1.2327630519866943 GLoss=2.8511481285095215 TP=0.7189712524414062 FP=0.7518978714942932
Epoch 10: 44800/50048                    DLoss=1.05274939