In [1]:
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import os
import torch.nn as nn
from dataset import SatelliteDataset
from torch.utils.data import DataLoader

In [2]:
def style_loss_fn(phi1, phi2):
    batch_size, c, h, w = phi1.shape
    psi1 = phi1.reshape((batch_size, c, w*h))
    psi2 = phi2.reshape((batch_size, c, w*h))
    
    gram1 = torch.matmul(psi1, torch.transpose(psi1, 1, 2)) / (c*h*w)
    gram2 = torch.matmul(psi2, torch.transpose(psi2, 1, 2)) / (c*h*w)
    # as described in johnson et al.
    return torch.sum(torch.norm(gram1 - gram2, p = "fro", dim=(1,2))) 


In [3]:
vgg_model = torch.hub.load('pytorch/vision:v0.9.0', 'vgg16', pretrained=True)
relu3_3 = torch.nn.Sequential(*vgg_model.features[:16])
relu4_3 = torch.nn.Sequential(*vgg_model.features[:23])

Using cache found in C:\Users\hugih/.cache\torch\hub\pytorch_vision_v0.9.0


In [4]:
normalize =  transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                     std=[0.5, 0.5, 0.5])
toTensor = transforms.ToTensor()

In [10]:
data_dir = "../../data/train"

In [11]:
rgb_files = os.listdir(os.path.join(data_dir, "rgb"))

In [12]:
def get_img(idx):
    img_fn = os.path.join(data_dir, "rgb", rgb_files[idx])
    with Image.open(img_fn) as img:
        img = toTensor(img)[:3,:,:]
        img = normalize(img)
    img = img.reshape((1, 3, 256, 256))
    return img

In [13]:
img1 = get_img(10)
img2 = get_img(2)

In [14]:
print(torch.max(img1))
print(torch.max(img2))

tensor(1.)
tensor(1.)


In [15]:
img1

tensor([[[[-0.6706, -0.6863, -0.6863,  ..., -0.8588, -0.8745, -0.8353],
          [-0.6863, -0.6941, -0.7020,  ..., -0.8980, -0.8824, -0.8745],
          [-0.6784, -0.6784, -0.6863,  ..., -0.8588, -0.8588, -0.8667],
          ...,
          [-0.7255, -0.7333, -0.7176,  ..., -0.8824, -0.8824, -0.8824],
          [-0.7412, -0.7255, -0.7098,  ..., -0.8824, -0.8824, -0.8745],
          [-0.7569, -0.7333, -0.7333,  ..., -0.8824, -0.8824, -0.8902]],

         [[-0.5451, -0.5373, -0.5608,  ..., -0.6392, -0.6157, -0.6157],
          [-0.5373, -0.5529, -0.5529,  ..., -0.7098, -0.7020, -0.6863],
          [-0.5373, -0.5451, -0.5373,  ..., -0.6627, -0.6706, -0.6627],
          ...,
          [-0.6157, -0.6078, -0.6078,  ..., -0.7569, -0.7569, -0.7647],
          [-0.6392, -0.6314, -0.6078,  ..., -0.7804, -0.8039, -0.7490],
          [-0.6314, -0.6392, -0.5843,  ..., -0.7725, -0.7647, -0.7725]],

         [[-0.7412, -0.7412, -0.7176,  ..., -0.8980, -0.8902, -0.8667],
          [-0.7176, -0.7412, -

In [16]:
phi1 = relu3_3(img1)
phi2 = relu3_3(img2)

In [20]:
mse = nn.MSELoss()

In [17]:
style_loss_fn(phi1, phi2)

tensor(0.4954, grad_fn=<SumBackward0>)

In [21]:
mse(phi1, phi2)

tensor(6.3471, grad_fn=<MseLossBackward>)

In [24]:
ds = SatelliteDataset("../../data/train")
loader = DataLoader(ds, 8)

In [31]:
from generator import Generator

In [32]:
gen = Generator()

In [34]:
for rgb_a, rgb_b, rgb_ab, lc_a, lc_b, lc_b_mask, lc_ab, masked_areas in loader:
    
    fake_img = gen(rgb_a, lc_a, lc_b_mask)
    feature_ab = relu3_3(rgb_ab)
    feature_fake = relu3_3(fake_img)
    
    st_loss = style_loss_fn(feature_ab, feature_fake)
    print("style loss",st_loss)
    print("mse loss", mse(feature_ab, feature_fake))
    st_loss.backward()
    
    break

style loss tensor(17.8600, grad_fn=<SumBackward0>)
mse loss tensor(8.4848, grad_fn=<MseLossBackward>)
