In [0]:
import torch
import torchvision
import numpy as np
import tqdm

In [0]:
dataset = torchvision.datasets.CIFAR10('./',train=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize([.5,.5,.5],[.5,.5,.5])]),download=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting ./cifar-10-python.tar.gz to ./


In [0]:
batch_size=256
num_workers = 4
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

In [0]:
nsd = 100
ndf = 64
ngf =64
nc = 3 # channels
num_classes = len(dataset.classes)
num_epochs = 50
lr=2e-3
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
img_shape = (32,32)

In [0]:
def get_noise_label(batch_size,nsd=nsd,device=device):
  ns = torch.cat([torch.randn(batch_size,nsd,1,1), torch.stack([torch.eye(num_classes)[np.random.choice(list(range(num_classes)))].unsqueeze(-1).unsqueeze(-1) for _ in range(batch_size)])], dim=1).to(device)
  label = ns[:,nsd:]
  return ns, label

def labels_to_onehots(labels,device=device):
  return torch.stack([torch.eye(num_classes)[label].unsqueeze(-1).unsqueeze(-1) for label in labels]).to(device)

def onehots_to_imgs(onehots,device=device):
  return torch.stack([torch.where(onehot==1, torch.ones(img_shape).to(device),torch.zeros(img_shape).to(device)) for onehot in onehots]).to(device)

In [0]:
torch.eye(num_classes)[4].unsqueeze(1).repeat(1,3).unsqueeze(-1).unsqueeze(-1).shape

torch.Size([10, 3, 1, 1])

In [0]:
onehots = labels_to_onehots(labels,device='cuda:0')
onehots.shape

torch.Size([256, 10, 1, 1])

In [0]:
ones = torch.ones(img_shape)
zeros = torch.zeros(img_shape)

In [0]:
onehots_to_imgs(onehots).shape

torch.Size([256, 10, 32, 32])

In [0]:
torch.stack([torch.where(onehot==1, ones,zeros) for onehot in onehots])[0]

RuntimeError: ignored

In [0]:
onehots_to_imgs(labels_to_onehots(labels))

RuntimeError: ignored

In [0]:
torch.stack([torch.ones(img_shape), torch.zeros(img_shape)]).repeat(3).shape

RuntimeError: ignored

In [0]:
next(iter(dataloader))[0].shape

torch.Size([256, 3, 32, 32])

In [0]:
labels_to_onehots(next(iter(dataloader))[1].to(device)).shape

torch.Size([256, 10, 3, 1, 1])

In [0]:
class Generator(torch.nn.Module):
  def __init__(self,nsd,num_classes):
    super(Generator, self).__init__()
    # 1x1 -> 4x4
    self.block1 = torch.nn.Sequential(
        torch.nn.ConvTranspose2d(nsd+num_classes,ngf*4,4,1,0),
        torch.nn.BatchNorm2d(ngf*4),
        torch.nn.ReLU(inplace=True)
    )
    # 4x4 -> 8x8
    self.block2 = torch.nn.Sequential(
        torch.nn.ConvTranspose2d(ngf*4, ngf*2,4,2,1),
        torch.nn.BatchNorm2d(ngf*2),
        torch.nn.ReLU(True)
    )
    # 8x8 -> 16x16
    self.block3 = torch.nn.Sequential(
        torch.nn.ConvTranspose2d(ngf*2,ngf,4,2,1),
        torch.nn.BatchNorm2d(ngf),
        torch.nn.ReLU(True)
    )
    # 16x16 -> 32x32
    self.block4 = torch.nn.Sequential(
        torch.nn.ConvTranspose2d(ngf,3,4,2,1),
        torch.nn.Tanh()
        # torch.nn.BatchNorm2d(1),
        # torch.nn.ReLU(True)
    )
    # 32x32 -> 64x64
    # self.block5 = torch.nn.Sequential(
        # torch.nn.ConvTranspose2d(ngf,nc,4,2,1),
        # torch.nn.Tanh()
    # )


  def forward(self,x):
    # print(x.shape)
    x = self.block1(x)
    # print(x.shape)
    x = self.block2(x)
    # print(x.shape)
    x = self.block3(x)
    # print(x.shape)
    x = self.block4(x)
    # print(x.shape)
    # x = self.block5(x)
    # print(x.shape)

    return x

In [0]:
class Discriminator(torch.nn.Module):
  def __init__(self):
    super(Discriminator,self).__init__()
    # 64x64 -> 32x32
    self.block1= torch.nn.Sequential(
        torch.nn.Conv2d(13, ndf,4,2,1),
        torch.nn.BatchNorm2d(ndf),
        torch.nn.LeakyReLU(0.2, True)
    )
    # 32x32 -> 16x16
    self.block2 = torch.nn.Sequential(
        torch.nn.Conv2d(ndf, ndf*2,4,2,1),
        torch.nn.BatchNorm2d(ndf*2),
        torch.nn.LeakyReLU(0.2, True)
    )
    # 16x16 -> 8x8
    self.block3 = torch.nn.Sequential(
        torch.nn.Conv2d(ndf*2,ndf*4,4,2,1),
        torch.nn.BatchNorm2d(ndf*4),
        torch.nn.LeakyReLU(0.2, True)
    )
    # 8x8 -> 4x4
    self.block4 = torch.nn.Sequential(
        torch.nn.Conv2d(ndf*4,1,4,1,0),
        torch.nn.Sigmoid()
        # torch.nn.BatchNorm2d(ndf*8),
        # torch.nn.LeakyReLU(0.2, True)
    )
    # # 4x4 -> 1x1
    # self.block5 = torch.nn.Sequential(
    #     torch.nn.Conv2d(ndf*8,1,4,1,0),
    #     torch.nn.Sigmoid()
    # )

  def forward(self,x):
    # print(x.shape)
    x = self.block1(x)
    # print(x.shape)
    x = self.block2(x)
    # print(x.shape)
    x = self.block3(x)
    # print(x.shape)
    x = self.block4(x)
    # print(x.shape)
    # x = self.block5(x)
    # print(x.shape)
    return x

In [0]:
x,label = get_noise_label()
# print(x.shape)
x.shape
# label_to_onehot(3).shape

torch.Size([256, 110, 1, 1])

In [0]:
G = Generator(nsd,num_classes)
G.to(device)
D = Discriminator()
D.to(device)

Discriminator(
  (block1): Sequential(
    (0): Conv2d(13, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (block2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (block3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (block4): Sequential(
    (0): Conv2d(256, 1, kernel_size=(4, 4), stride=(1, 1))
    (1): Sigmoid()
  )
)

In [0]:
# D(G(x))
criterion = torch.nn.BCELoss()
g_optimizer = torch.optim.Adam(G.parameters(),lr=lr)
d_optimizer = torch.optim.Adam(D.parameters(),lr=lr)

In [0]:
next(iter(dataloader))[1].shape

torch.Size([256])

In [0]:
for epoch in range(1,num_epochs+1):
  D.train()
  G.train()
  d_epoch_loss = 0.
  g_epoch_loss = 0.

  d_epoch_imgs = 0.
  g_epoch_imgs = 0.

  for (data, labels) in tqdm.tqdm(dataloader):
    fake_label = torch.ones(data.size(0)).to(device)
    real_label = torch.zeros(data.size(0)).to(device)
    data.to(device)
    labels.to(device)
    d_epoch_imgs += 2*data.size(0)

    ns,ns_label = get_noise_label(data.size(0))

    d_optimizer.zero_grad()
    fake_images = G(ns)
    fake_G_inputs = torch.cat([fake_images,onehots_to_imgs(ns_label)], dim=1)
    out = D(fake_G_inputs)
    loss = criterion(out, fake_label)
    loss.backward()
    d_epoch_loss += loss.item()

    real_G_inputs = torch.cat([data.to(device), onehots_to_imgs(labels_to_onehots(labels).to(device)).to(device)],dim=1)
    out = D(real_G_inputs)
    loss = criterion(out, real_label)
    loss.backward()
    d_optimizer.step()
    d_epoch_loss += loss.item()
  ns,ns_label = get_noise_label(1024)
  g_epoch_imgs = ns.size(0)
  g_optimizer.zero_grad()
  fake_images = G(ns)
  fake_G_inputs = torch.cat([fake_images,onehots_to_imgs(ns_label)], dim=1)
  out = D(fake_G_inputs)
  real_label = torch.zeros(1024).to(device)

  loss = criterion(out, real_label)
  loss.backward()
  g_optimizer.step()
  g_epoch_loss = loss.item()

  print('Epoch{}: D_Loss:{:.4f} G_Loss:{:.4f}'.format(epoch, d_epoch_loss/d_epoch_imgs, g_epoch_loss/g_epoch_imgs))






  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)



  1%|          | 1/196 [00:00<02:26,  1.34it/s][A[A[A


  1%|          | 2/196 [00:01<01:59,  1.62it/s][A[A[A


  2%|▏         | 3/196 [00:01<01:41,  1.90it/s][A[A[A


  2%|▏         | 4/196 [00:01<01:27,  2.18it/s][A[A[A


  3%|▎         | 5/196 [00:01<01:18,  2.45it/s][A[A[A


  3%|▎         | 6/196 [00:02<01:11,  2.66it/s][A[A[A


  4%|▎         | 7/196 [00:02<01:06,  2.84it/s][A[A[A


  4%|▍         | 8/196 [00:02<01:03,  2.97it/s][A[A[A


  5%|▍         | 9/196 [00:03<01:00,  3.10it/s][A[A[A


  5%|▌         | 10/196 [00:03<00:57,  3.22it/s][A[A[A


  6%|▌         | 11/196 [00:03<00:56,  3.28it/s][A[A[A


  6%|▌         | 12/196 [00:04<00:54,  3.37it/s][A[A[A


  7%|▋         | 13/196 [00:04<00:53,  3.44it/s][A[A[A


  7%|▋         | 14/196 [00:04<00:51,  3.52it/s][A[A[A


  8%|▊         | 15/196 [00:04<00:51,  3.48it/s][A[A[A


  8%|▊    

Epoch1: D_Loss:0.0001 G_Loss:0.0153





  1%|          | 1/196 [00:00<02:26,  1.33it/s][A[A[A


  1%|          | 2/196 [00:01<01:58,  1.64it/s][A[A[A


  2%|▏         | 3/196 [00:01<01:37,  1.97it/s][A[A[A


  2%|▏         | 4/196 [00:01<01:23,  2.30it/s][A[A[A


  3%|▎         | 5/196 [00:01<01:13,  2.61it/s][A[A[A


  3%|▎         | 6/196 [00:02<01:06,  2.88it/s][A[A[A


  4%|▎         | 7/196 [00:02<01:00,  3.10it/s][A[A[A


  4%|▍         | 8/196 [00:02<00:57,  3.25it/s][A[A[A


  5%|▍         | 9/196 [00:02<00:55,  3.38it/s][A[A[A


  5%|▌         | 10/196 [00:03<00:54,  3.44it/s][A[A[A


  6%|▌         | 11/196 [00:03<00:52,  3.52it/s][A[A[A


  6%|▌         | 12/196 [00:03<00:51,  3.56it/s][A[A[A


  7%|▋         | 13/196 [00:03<00:50,  3.63it/s][A[A[A


  7%|▋         | 14/196 [00:04<00:49,  3.65it/s][A[A[A


  8%|▊         | 15/196 [00:04<00:48,  3.70it/s][A[A[A


  8%|▊         | 16/196 [00:04<00:48,  3.70it/s][A[A[A


  9%|▊         | 17/196 [00:05<00:48,  3.71it/

KeyboardInterrupt: ignored

In [0]:
data.shape

torch.Size([256, 3, 32, 32])

In [0]:
fake_images.shape

torch.Size([256, 3, 32, 32])

In [0]:
fake_G_inputs.shape

torch.Size([256, 11, 32, 32])

In [0]:
onehots_to_imgs(labels_to_onehots(labels)).shape

torch.Size([256, 10, 32, 32])