In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
from torchvision import transforms
import torchvision.datasets as dset
import torchvision.utils as vutils
from torchvision import models
import numpy as np
import matplotlib.pyplot as plt
from torch import sigmoid
import torch.nn.functional as F

In [None]:
# Mount Google Drive into the Colab runtime; the dataset archives and the saved
# network checkpoints used below are read from / written to '/content/drive/My Drive'.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Extract both image archives from Drive onto the runtime's local disk.
# -q keeps the output quiet; -n never overwrites files that already exist, so
# re-running this cell does not trigger unzip's interactive "replace?" prompt.
!unzip -q -n '/content/drive/My Drive/coco2017.zip' -d '/content/coco2017'
!unzip -q -n '/content/drive/My Drive/coco2017_cart.zip' -d '/content/coco2017_cart'

In [None]:
# Sanity check: list the working directory to confirm the extracted folders are present.
!ls

coco2017  coco2017_cart  drive	sample_data


In [None]:
# Folders with the images (cartoon-styled set and real-photo set)
dataroot_cart = "/content/coco2017_cart/"
dataroot_real = "/content/coco2017/"

# Number of DataLoader worker processes (passed as num_workers)
workers = 0

# Number of images per batch
batch_size = 32

# Target image side length in pixels
image_size = 256

# Learning rate for the Adam optimizer
lr = 0.0002

# First beta coefficient for the Adam optimizer (used as betas=(beta1, 0.999))
beta1 = 0.5

# Number of GPUs to use; with ngpu = 0 the device is 'cpu'
ngpu = 1

In [None]:
# Pick the compute device: the first CUDA GPU when one is available and ngpu > 0,
# otherwise fall back to the CPU.
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
# Only query the GPU name when CUDA is actually in use: torch.cuda.get_device_name()
# raises on a machine without a usable GPU, which would break the CPU fallback.
if device.type == "cuda":
    print(torch.cuda.get_device_name(device))
print(device)

Tesla T4
cuda:0


In [None]:
# Shared preprocessing for both datasets: center-crop to image_size x image_size
# (taken from the config cell rather than a duplicated literal), then convert the
# image to a tensor.
transformer = transforms.Compose([
    transforms.CenterCrop(image_size),
    transforms.ToTensor()
])

# Cartoon images
dataset_cart = dset.ImageFolder(root=dataroot_cart, transform=transformer)

# shuffle=False on both loaders: the training loop zips the two loaders together,
# so batches are matched purely by position.
# NOTE(review): this presumably relies on both folders listing corresponding
# images in the same order — confirm the two archives share file names.
dataloader_cart = torch.utils.data.DataLoader(dataset_cart, batch_size=batch_size,
                                            shuffle=False, num_workers=workers)

# Real images
dataset_real = dset.ImageFolder(root=dataroot_real, transform=transformer)

dataloader_real = torch.utils.data.DataLoader(dataset_real, batch_size=batch_size,
                                            shuffle=False, num_workers=workers)

In [None]:
class res_block(nn.Module):
  """Residual block for 256-channel feature maps.

  The branch is conv -> batch norm -> ReLU -> conv -> batch norm, using 3x3
  convolutions with stride 1 and padding 1, and its result is added to the
  block input. Channel count and spatial size are therefore unchanged.
  """

  def __init__(self):
    super().__init__()
    channels = 256
    conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
    # Attribute names become state_dict keys, so they must stay as they are.
    self.conv_1 = nn.Conv2d(channels, channels, **conv_kwargs)
    self.conv_2 = nn.Conv2d(channels, channels, **conv_kwargs)
    self.norm_1 = nn.BatchNorm2d(channels)
    self.norm_2 = nn.BatchNorm2d(channels)

  def forward(self, x):
    """Return ``x`` plus the output of the convolutional branch applied to ``x``."""
    branch = self.conv_1(x)
    branch = self.norm_1(branch)
    branch = F.relu(branch)
    branch = self.conv_2(branch)
    branch = self.norm_2(branch)
    return branch + x

class Generator(nn.Module):
    """Convolutional image-to-image generator.

    Layout: a 7x7 input convolution, two down-sampling stages (each a stride-2
    3x3 convolution followed by a stride-1 3x3 convolution), eight residual
    blocks on 256 channels, two up-sampling stages (each a stride-2 transposed
    convolution followed by a stride-1 transposed convolution) and a 7x7 output
    convolution mapped through a sigmoid.

    Input and output both have 3 channels.
    """

    def __init__(self):
        super().__init__()
        # Attribute names double as state_dict keys; they are kept unchanged so
        # previously saved checkpoints still load.

        # Input stage: 3 -> 64 channels.
        self.conv_1 = nn.Conv2d(3, 64, kernel_size=7, padding=3)
        self.norm_1 = nn.BatchNorm2d(64)

        # Down-sampling stage 1: 64 -> 128 channels.
        self.conv_2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.conv_3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.norm_2 = nn.BatchNorm2d(128)

        # Down-sampling stage 2: 128 -> 256 channels.
        self.conv_4 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.conv_5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.norm_3 = nn.BatchNorm2d(256)

        # Eight residual blocks applied one after another.
        self.res = nn.Sequential(*(res_block() for _ in range(8)))

        # Up-sampling stage 1: 256 -> 128 channels.
        self.conv_6 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2,
                                         padding=1, output_padding=1)
        self.conv_7 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1)
        self.norm_4 = nn.BatchNorm2d(128)

        # Up-sampling stage 2: 128 -> 64 channels.
        self.conv_8 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2,
                                         padding=1, output_padding=1)
        self.conv_9 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1)
        self.norm_5 = nn.BatchNorm2d(64)

        # Output stage: 64 -> 3 channels.
        self.conv_10 = nn.Conv2d(64, 3, kernel_size=7, padding=3)

    def forward(self, x):
        """Run the generator on ``x`` and return the sigmoid-activated output."""
        # Input stage.
        features = F.relu(self.norm_1(self.conv_1(x)))

        # Down-sampling stages.
        features = self.conv_3(self.conv_2(features))
        features = F.relu(self.norm_2(features))
        features = self.conv_5(self.conv_4(features))
        features = F.relu(self.norm_3(features))

        # Residual trunk.
        features = self.res(features)

        # Up-sampling stages.
        features = self.conv_7(self.conv_6(features))
        features = F.relu(self.norm_4(features))
        features = self.conv_9(self.conv_8(features))
        features = F.relu(self.norm_5(features))

        # Sigmoid keeps every output value strictly between 0 and 1.
        return sigmoid(self.conv_10(features))

In [None]:
# Checkpoint to resume from: the file path_nets + "netG" + net_label + "_state".
net_label = "35"
path_nets = '/content/drive/My Drive/nets_cart/'

netG = Generator().to(device)
# map_location remaps the stored tensors onto the active device, so a checkpoint
# written from a GPU session also loads when this notebook runs on the CPU.
netG.load_state_dict(torch.load(path_nets + "netG" + net_label + "_state",
                                map_location=device))

In [None]:
def print_photos(phot, phot_cart=None, nrow=8):
    """Show images as a single grid with matplotlib.

    Args:
        phot: images to display; handed to ``vutils.make_grid`` (a batch tensor
            or a sequence of image tensors).
        phot_cart: optional second set of images. When given, the images of
            ``phot`` and ``phot_cart`` are concatenated into one list, so with
            ``len(phot) == nrow`` the second set appears as the row below.
        nrow: number of images per grid row (forwarded to ``make_grid``).
    """
    images = phot if phot_cart is None else [*phot, *phot_cart]
    grid = vutils.make_grid(images, normalize=True, nrow=nrow)
    # Detach and convert explicitly: converting a tensor that still requires grad
    # to a NumPy array raises, and passing a torch tensor to np.transpose only
    # works through NumPy's fallback conversion. permute moves the channel axis
    # last (C, H, W -> H, W, C), the layout imshow expects.
    plt.imshow(grid.detach().cpu().permute(1, 2, 0).numpy())
    plt.show()


def save_net(net, path):
    """Store the weights of ``net`` in the file ``path + "_state"``.

    Only the ``state_dict`` is written, not the module object itself.
    """
    weights = net.state_dict()
    target = path + "_state"
    torch.save(weights, target)



In [None]:
# Binary cross-entropy loss; the training loop applies it to the generator output
# and the cartoon batch.
criterion = nn.BCELoss()

# Adam over all generator parameters, with lr and beta1 taken from the config cell.
optimizerG = optim.Adam(
    netG.parameters(),
    lr=lr,
    betas=(beta1, 0.999),
)

In [None]:
# Number of passes over the data.
n_epoch = 51
# Index used in the next checkpoint file name ("netG" + str(net_number) + "_state").
# NOTE(review): the generator is loaded from net_label "35" while numbering
# restarts at 31, so checkpoints 31 and up get overwritten — confirm this is intended.
net_number = 31
# Save a checkpoint / show sample images every this many iterations.
save_period = 75
print_period = 75

# zip() stops at the shorter loader, so that length is the iteration count per
# epoch. It does not change between epochs, hence it is computed once.
n_iters = min(len(dataloader_cart), len(dataloader_real))

for epoch in range(n_epoch):

    # i is 1-based: it runs from 1 to n_iters inclusive.
    for i, (cart_batch, real_batch) in enumerate(zip(dataloader_cart, dataloader_real), 1):

        # Load the cartoon images (element 0 of the batch; element 1 is the
        # ImageFolder class label, which is not used).
        cart_images = cart_batch[0].to(device)

        # Load the real photos.
        real_images = real_batch[0].to(device)

        ####################################
        # Generator training

        optimizerG.zero_grad()

        output = netG(real_images)

        # The generator output for a real photo is compared with the cartoon
        # image at the same position in the batch.
        errG = criterion(output, cart_images)
        errG.backward()

        optimizerG.step()

        # Show samples periodically, plus once early in each epoch (i == 2).
        # The output is detached because it is only plotted, never back-propagated.
        if i % print_period == 0 or i == 2:
            print_photos(real_images[:8], output[:8].detach())
        if i % 20 == 0:
            print("epoch - ", epoch, ": ", i, "/", n_iters, "   errG - ", errG.item())
        # Checkpoint periodically and after the final batch of the epoch. Since i
        # is 1-based, the final batch is i == n_iters (not n_iters - 1).
        if i == n_iters or i % save_period == 0:
            save_net(netG, path_nets + "netG" + str(net_number))
            net_number += 1


Output hidden; open in https://colab.research.google.com to view.

In [None]:
# Optional manual cleanup after training: uncomment the `del` lines to drop the
# references to the last batch tensors before emptying the CUDA cache.
#del real_images
#del cart_images
#del output
# Ask PyTorch to release cached GPU memory that is currently unused.
torch.cuda.empty_cache()
# Reset the gradients left over from the last optimisation step.
optimizerG.zero_grad()

In [None]:
# Flush pending writes (e.g. the saved checkpoints) to Google Drive and unmount it.
drive.flush_and_unmount()