<a href="https://colab.research.google.com/github/amansyayf/2016-solar_project/blob/master/notebooks/models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision.models import vgg19
import torchvision.transforms as transforms

from PIL import Image

**Create Generation Network**

In [None]:
class ConvLayer(nn.Module):
    """Stride-1 2-D convolution with padding ``kernel_size // 2``.

    For odd ``kernel_size`` this keeps the spatial size of the input
    unchanged. The wrapped ``nn.Conv2d`` is exposed as ``self.conv``.
    """

    def __init__(self, in_c, out_c, kernel_size):
        super().__init__()
        half = int(kernel_size // 2)  # same as floor(kernel_size / 2)
        self.conv = nn.Conv2d(
            in_c,
            out_c,
            kernel_size=kernel_size,
            stride=1,
            padding=half,
        )

    def forward(self, x):
        return self.conv(x)

In [None]:
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 reduce -> kxk -> 1x1 expand + shortcut.

    The main branch (``identity_block``) narrows to ``out_c // 4`` channels
    with a 1x1 conv, applies a ``kernel_size`` conv, and widens back to
    ``out_c`` with a 1x1 conv. Each conv is followed by ``InstanceNorm2d``
    and ReLU. The shortcut is the input itself when ``in_c == out_c``;
    otherwise it is a learned 1x1 conv + ``InstanceNorm2d`` projection. The
    two branches are summed and passed through a final ReLU. All convs are
    stride 1, so for odd ``kernel_size`` the spatial size is preserved.

    Args:
        in_c: number of input channels.
        out_c: number of output channels (inner width is ``out_c // 4``).
        kernel_size: kernel size of the middle convolution.
        stride: accepted for API compatibility; currently unused.
    """

    def __init__(self, in_c, out_c, kernel_size=3, stride=1):
        super().__init__()
        self.in_c = in_c
        self.out_c = out_c
        self.kernel_size = kernel_size
        mid_c = out_c // 4
        self.identity_block = nn.Sequential(
            ConvLayer(in_c, mid_c, kernel_size=1),
            nn.InstanceNorm2d(mid_c),
            nn.ReLU(),
            ConvLayer(mid_c, mid_c, kernel_size),
            nn.InstanceNorm2d(mid_c),
            nn.ReLU(),
            ConvLayer(mid_c, out_c, kernel_size=1),
            nn.InstanceNorm2d(out_c),
            nn.ReLU(),
        )
        # Build the shortcut once, in __init__, so the projection weights
        # are registered parameters: they are trained, saved in state_dict
        # and moved by .to(device). Creating them inside forward would give
        # new random weights on every call.
        if in_c == out_c:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Sequential(
                ConvLayer(in_c, out_c, kernel_size=1),
                nn.InstanceNorm2d(out_c),
            )

    def residual(self, x):
        """Apply the shortcut branch to ``x`` (identity or 1x1 projection)."""
        return self.shortcut(x)

    def forward(self, x):
        # Sum the main branch and the shortcut; the previous `out =+ ...`
        # discarded the main branch entirely.
        out = self.identity_block(x) + self.residual(x)
        return F.relu(out)



In [None]:
class UpSample(nn.Module):
    """Stride-2 3x3 conv, then interpolation, ``InstanceNorm2d`` and ReLU.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        scale_factor: factor passed to ``F.interpolate`` after the conv.
        mode: interpolation mode for ``F.interpolate`` (default
            ``'bilinear'``). ``align_corners=False`` is passed only for modes
            that accept it, so e.g. ``'nearest'`` also works.

    NOTE(review): the stride-2 convolution first halves H and W
    (``ceil(H / 2)``), so the net resize factor is ``scale_factor / 2``.
    For example, ``scale_factor=2`` keeps an even-sized input's size.
    Confirm this is intended for an "up-sampling" layer.
    """

    # Modes for which F.interpolate accepts an ``align_corners`` value.
    _ALIGN_CORNERS_MODES = frozenset({'linear', 'bilinear', 'bicubic', 'trilinear'})

    def __init__(self, in_channels, out_channels, scale_factor, mode='bilinear'):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.norm = nn.InstanceNorm2d(out_channels)

    def forward(self, x):
        # F.interpolate raises if align_corners is set for 'nearest'/'area'.
        align = False if self.mode in self._ALIGN_CORNERS_MODES else None
        out = self.conv(x)
        out = F.interpolate(out, scale_factor=self.scale_factor, mode=self.mode, align_corners=align)
        out = self.norm(out)
        out = F.relu(out)
        return out

In [None]:
def upsample(scale_factor):
    """Return an ``nn.Upsample`` module doing bilinear resizing by ``scale_factor``.

    This builds a module rather than transforming a tensor; apply it as
    ``upsample(2)(x)``.
    """
    resize = nn.Upsample(mode='bilinear', scale_factor=scale_factor)
    return resize

In [None]:
class HRNet(nn.Module):
    """Multi-resolution generator that fuses full, 1/2 and 1/4 resolution branches.

    Takes a 3-channel input ``(b, 3, H, W)`` and returns a 3-channel tensor
    with the same H and W.

    * Stage 2 has a full-resolution ``Bottleneck`` branch (16->32) and a
      1/2-resolution branch made by a stride-2 conv.
    * Stage 3 builds full, 1/2 and 1/4 resolution maps. Each one
      concatenates a contribution from both stage-2 branches (64 channels).
    * Stage 4 brings all three to full resolution, concatenates them
      (192 channels) and decodes 192->64->32->16->3.

    Downsampling uses strided convs; upsampling is bilinear interpolation to
    the exact size of the full-resolution map, so odd H/W stay aligned.
    """

    def __init__(self):
        super().__init__()
        self.layer1_1 = Bottleneck(3, 16)

        self.layer2_1 = Bottleneck(16, 32)
        self.downsample2_1 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)

        self.layer3_1 = Bottleneck(32, 32)
        self.layer3_2 = Bottleneck(32, 32)
        self.downsample3_1 = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1)
        self.downsample3_2 = nn.Conv2d(32, 32, kernel_size=3, stride=4, padding=1)
        self.downsample3_3 = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1)

        self.layer4_1 = Bottleneck(64, 64)
        self.layer5_1 = Bottleneck(192, 64)
        self.layer6_1 = Bottleneck(64, 32)
        self.layer7_1 = Bottleneck(32, 16)
        self.layer8_1 = nn.Conv2d(16, 3, kernel_size=3, stride=1, padding=1)

    @staticmethod
    def _upsample_to(x, ref):
        """Bilinearly resize ``x`` to the spatial size (last two dims) of ``ref``."""
        return F.interpolate(x, size=ref.shape[-2:], mode='bilinear', align_corners=False)

    def forward(self, x):
        # Stage 1: lift the 3 input channels to 16 at full resolution.
        map1_1 = self.layer1_1(x)

        # Stage 2: full-resolution branch and 1/2-resolution branch.
        map2_1 = self.layer2_1(map1_1)
        map2_2 = self.downsample2_1(map1_1)

        # Stage 3: full, 1/2 and 1/4 resolution maps, 64 channels each.
        full3 = self.layer3_1(map2_1)
        map3_1 = torch.cat((full3, self._upsample_to(map2_2, full3)), 1)
        map3_2 = torch.cat((self.downsample3_1(map2_1), self.layer3_2(map2_2)), 1)
        map3_3 = torch.cat((self.downsample3_2(map2_1), self.downsample3_3(map2_2)), 1)

        # Stage 4: fuse all branches at full resolution (3 x 64 = 192 channels).
        full4 = self.layer4_1(map3_1)
        map4_1 = torch.cat(
            (full4, self._upsample_to(map3_2, full4), self._upsample_to(map3_3, full4)),
            1,
        )

        # Decoder back to 3 channels.
        out = self.layer5_1(map4_1)
        out = self.layer6_1(out)
        out = self.layer7_1(out)
        out = self.layer8_1(out)

        return out

**Create utility functions**

In [None]:
def image_loading(path, size=None):
  """Load an image file as a normalised ``(1, 3, H, W)`` float tensor.

  The image is converted to RGB, so grayscale, palette and RGBA files all
  give 3 channels, which the 3-value ``Normalize`` below requires. If
  ``size`` is given, the image is resized to ``size x size`` (aspect ratio
  is not preserved). It is then converted with ``ToTensor``, normalised
  with mean (0.485, 0.456, 0.406) / std (0.229, 0.224, 0.225), and given a
  leading batch dimension.

  Args:
    path: path (or file object) accepted by ``PIL.Image.open``.
    size: optional edge length of the square output image.
  """
  # Use a context manager so the underlying file handle is closed;
  # convert() returns a fully loaded copy that outlives the `with`.
  with Image.open(path) as src:
    img = src.convert('RGB')

  if size is not None:
    img = img.resize((size, size))

  transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
  ])

  img = transform(img)
  img = img.unsqueeze(0)
  return img

In [None]:
def im_convert(img):
    """Turn a normalised ``(1, 3, H, W)`` tensor into a displayable ``(H, W, 3)`` array.

    Detaches and copies the tensor to CPU, drops the batch axis, moves
    channels last, undoes the mean (0.485, 0.456, 0.406) / std
    (0.229, 0.224, 0.225) normalisation, and clips to ``[0, 1]``.
    """
    mean = np.array((0.485, 0.456, 0.406))
    std = np.array((0.229, 0.224, 0.225))
    arr = img.detach().to('cpu').clone().numpy()
    hwc = np.transpose(np.squeeze(arr, axis=0), (1, 2, 0))
    return np.clip(hwc * std + mean, 0, 1)

In [None]:
def get_features(img, model, layers=None):
    """Feed ``img`` through ``model``'s children in order and collect activations.

    Args:
        img: input tensor given to the first child of ``model``.
        model: container whose direct children are applied sequentially,
            each one's output feeding the next (e.g. an ``nn.Sequential``).
        layers: mapping from child name (its key in ``model._modules``,
            i.e. the index as a string for ``nn.Sequential``) to the key
            under which that child's output is stored. Defaults to the
            indices labelled conv1_1..conv5_1 (style) and conv4_2 (content),
            presumably for the ``features`` part of the imported ``vgg19``
            — TODO confirm.

    Returns:
        dict mapping each requested feature name to the corresponding output.
        Children after the last requested one are not run, since their
        outputs would be discarded.
    """
    if layers is None:
        layers = {
            '0': 'conv1_1',   # style layer
            '5': 'conv2_1',   # style layer
            '10': 'conv3_1',  # style layer
            '19': 'conv4_1',  # style layer
            '28': 'conv5_1',  # style layer

            '21': 'conv4_2'   # content layer
        }

    pending = set(layers)  # child names not yet captured
    features = {}
    x = img
    # Iterate _modules directly (as nn.Sequential.forward does) so a module
    # instance that appears twice is applied twice; named_children() would
    # de-duplicate it.
    for name, layer in model._modules.items():
        if not pending:
            break  # everything requested has been captured
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
            pending.discard(name)

    return features


In [None]:
def get_gram_matrix(img):
    """Return the (unnormalised) Gram matrix of the channels of ``img``.

    Args:
        img: 4-D tensor of shape ``(b, c, h, w)``.

    Returns:
        ``(c, c)`` tensor of channel inner products when ``b == 1``;
        ``(b, c, c)`` stack of per-sample Gram matrices when ``b > 1``.
    """
    b, c, h, w = img.size()
    # reshape (unlike view) also accepts non-contiguous inputs.
    flat = img.reshape(b, c, h * w)
    # Per-sample product: flattening to (b*c, h*w) would add cross-sample
    # correlations and give a (b*c, b*c) matrix.
    gram = torch.bmm(flat, flat.transpose(1, 2))
    if b == 1:
        return gram.squeeze(0)  # keep the original (c, c) result shape
    return gram