In [1]:
import itertools
import torch
import torch.nn as nn
import torch.optim as optim

from PIL import Image
from torchvision import transforms
from torchvision.utils import make_grid
from torchsummary import summary

from matplotlib import pyplot as plt

%matplotlib inline

# Run on the GPU whenever CUDA is usable; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("using device: %s" % device)

using device: cuda


In [3]:
# prepare data

import os

from appleorange import AppleOrangeDataset
from torch.utils.data import DataLoader

IMAGE_SIZE = 256

# Preprocessing pipeline applied to every image:
#   1. resize to int(IMAGE_SIZE * 1.33) == 340, slightly larger than the target,
#   2. take a random IMAGE_SIZE x IMAGE_SIZE crop and randomly mirror it,
#   3. convert to a tensor and normalise each channel with mean 0.5 / std 0.5,
#      which maps ToTensor's [0, 1] range onto [-1, 1].
transform = transforms.Compose([
    transforms.Resize(int(IMAGE_SIZE * 1.33)),
    transforms.RandomCrop(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])

# Root folder holding the four image directories. The default is the original
# hard-coded location; set the APPLE_ORANGE_DATA_DIR environment variable to run
# the notebook against a copy of the dataset stored elsewhere.
DATA_ROOT = os.environ.get("APPLE_ORANGE_DATA_DIR", r"E:\datasets\apple-orange")

apple_train_dir = os.path.join(DATA_ROOT, "apples_train")
orange_train_dir = os.path.join(DATA_ROOT, "oranges_train")

apple_test_dir = os.path.join(DATA_ROOT, "apples_test")
orange_test_dir = os.path.join(DATA_ROOT, "oranges_test")

# NOTE(review): the test dataset reuses the randomly-augmenting `transform`, so
# test samples differ between runs -- confirm this is intended.
train_ds = AppleOrangeDataset(apple_train_dir, orange_train_dir, transforms=transform, device=device)
test_ds = AppleOrangeDataset(apple_test_dir, orange_test_dir, transforms=transform, device=device)

# Batching is delegated to the collate function each dataset provides.
train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, collate_fn=train_ds.collate_fn)
test_dl = DataLoader(test_ds, batch_size=5, shuffle=False, collate_fn=test_ds.collate_fn)

print("train data: %d, test data: %d" % (len(train_ds), len(test_ds)))


train data: 3067, test data: 1021


In [None]:
# prepare network
def weight_init_normal(m):
    """Re-initialise the parameters of a single module in place.

    Meant to be handed to ``nn.Module.apply``, which calls it once per
    submodule. Modules are selected by class name:

    * name contains ``"Conv"``: weight ~ Normal(0.0, 0.02), bias set to 0;
    * name contains ``"BatchNorm2d"``: weight ~ Normal(1.0, 0.02), bias set to 0;
    * anything else keeps its default initialisation.

    Because matching is name-based, a module can match without owning the
    corresponding parameter (e.g. ``nn.BatchNorm2d(..., affine=False)``, a
    convolution built with ``bias=False``, or a container class whose name
    happens to contain "Conv"). Missing or ``None`` parameters are skipped
    instead of raising ``AttributeError``.

    Args:
        m: the module to initialise.
    """
    classname = m.__class__.__name__
    # getattr with a default covers both "attribute absent" and "registered as None".
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if classname.find("Conv") != -1:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias.data, 0.0)
    elif classname.find("BatchNorm2d") != -1:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias.data, 0.0)
    # Any other module type: leave the default initialisation untouched.

class ResBlock(nn.Module):
    """Residual block that keeps both channel count and spatial size.

    The residual branch is two 3x3 convolutions, each preceded by a 1-pixel
    reflection pad (so the unpadded convolution preserves height and width) and
    followed by instance normalisation, with a ReLU between them. The branch
    output is added to the unmodified input.

    Args:
        in_features: number of input channels, which is also the number of
            output channels.
    """

    def __init__(self, in_features):
        super().__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features)
        )

    def forward(self, xs):
        """Return ``block(xs) + xs`` (identity skip connection)."""
        # The attribute is `block`; the previous spelling `bolck` raised
        # AttributeError on every forward pass.
        return self.block(xs) + xs

class GeneratorResNet(nn.Module):
    """ResNet-style image-to-image generator for 3-channel images.

    Layer sequence built in ``__init__``:

    1. Stem: reflection pad 3 + 7x7 conv (3 -> 64), instance norm, ReLU.
    2. Two downsampling stages: 3x3 stride-2 conv, instance norm, ReLU
       (64 -> 128 -> 256 channels; each stage halves height and width).
    3. ``num_residual_blocks`` ``ResBlock`` modules at 256 channels.
    4. Two upsampling stages: ``nn.Upsample(scale_factor=2)`` followed by a
       3x3 stride-1 conv, instance norm, ReLU (256 -> 128 -> 64 channels).
    5. Output: reflection pad 3 + 7x7 conv (64 -> 3) and ``Tanh``, so output
       values lie in [-1, 1].

    After construction every submodule is passed through
    ``weight_init_normal``.

    Args:
        num_residual_blocks: how many residual blocks to place between the
            downsampling and upsampling stages.
    """

    def __init__(self, num_residual_blocks=9):
        # nn.Module.__init__ takes no positional arguments; the previous
        # `super().__init__(self)` raised TypeError.
        super().__init__()
        out_features = 64
        channels = 3

        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(channels, out_features, kernel_size=7),
            nn.InstanceNorm2d(out_features),
            nn.ReLU()
        ]

        # Downsampling
        in_features = out_features
        for _ in range(2):
            out_features *= 2
            layers += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU()
            ]
            in_features = out_features

        # Residual blocks
        for _ in range(num_residual_blocks):
            layers.append(ResBlock(out_features))

        # Upsampling
        for _ in range(2):
            out_features //= 2
            layers += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU()
            ]
            in_features = out_features

        # Output layer
        layers += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(out_features, channels, kernel_size=7),
            nn.Tanh()
        ]

        # nn.Sequential expects the modules as separate positional arguments,
        # so the list must be unpacked; passing the list itself raised TypeError.
        self.model = nn.Sequential(*layers)
        self.apply(weight_init_normal)

    def forward(self, xs):
        """Run ``xs`` through the full layer sequence and return the result."""
        return self.model(xs)

