In [11]:
import itertools
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision

import wwgan_utils

In [3]:
class AttributeDict(dict):
    """dict whose keys can also be read, written and deleted as attributes.

    Used as a lightweight namespace for hyper-parameters, e.g. ``args.beta = 0.5``
    stores ``args['beta'] = 0.5``.
    """

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails. A missing key must
        # surface as AttributeError (not KeyError) so that hasattr(),
        # getattr(obj, name, default) and copy.deepcopy work as usual.
        try:
            return self[attr]
        except KeyError:
            raise AttributeError(attr) from None

    def __setattr__(self, attr, value):
        self[attr] = value

    def __delattr__(self, attr):
        # Mirror __setattr__: attributes live in the dict, not in __dict__.
        try:
            del self[attr]
        except KeyError:
            raise AttributeError(attr) from None

In [4]:
# Training hyper-parameters, accessible both as args.beta and args['beta'].
args = AttributeDict(beta=0.5, epoch=10)

In [26]:
def shift_diamond(X, a, b, c=0, d=0, dim=0, diag=False):
    """
    Apply a fixed stencil to a batch of 3-channel images.

    Implemented as a dense 3->3 convolution (cross-correlation, stride 1, no
    padding). For the spatial stencils the weight is block-diagonal, so each
    channel is filtered independently.

    Args:
        X: tensor of shape (batch, 3, im_height, im_width).
        a, b, c, d: stencil coefficients. Terms are *added*, so
            ``a*X[i] - b*X[i+1]`` is obtained by passing a negative ``b``.
        dim: direction used when ``diag`` is False:
            0 -> horizontal: a*X[..., j] + b*X[..., j+1] + c*X[..., j+2];
                 output (batch, 3, H, W-2)
            1 -> vertical: the same stencil along the height axis;
                 output (batch, 3, H-2, W)
            2 -> between channels with a 1x1 kernel (c, d unused):
                 out0 = a*x0 + b*x1, out1 = a*x0 + b*x2, out2 = a*x1 + b*x2;
                 output (batch, 3, H, W)
        diag: if True, ignore ``dim`` and apply the 2x2 kernel
            [[a, b], [c, d]] to each channel; output (batch, 3, H-1, W-1).

    Returns:
        The filtered tensor, on X's device and with X's dtype. The filter is a
        constant, so gradients flow to X only.

    Raises:
        ValueError: if ``diag`` is False and ``dim`` is not 0, 1 or 2.
    """
    # NOTE(review): the dim=2 mixing pattern (out1 = a*x0 + b*x2) is not a
    # cyclic channel shift; confirm that this is intended.
    opts = {"dtype": X.dtype, "device": X.device}
    if not diag and dim == 2:
        weight = torch.tensor([[a, b, 0], [a, 0, b], [0, a, b]], **opts).view(3, 3, 1, 1)
    else:
        if diag:
            stencil = [[a, b], [c, d]]
        elif dim == 0:
            stencil = [[a, b, c]]
        elif dim == 1:
            stencil = [[a], [b], [c]]
        else:
            raise ValueError(f"dim must be 0, 1 or 2 when diag is False, got {dim!r}")
        kernel = torch.tensor(stencil, **opts)
        # Block-diagonal weight of shape (out=3, in=3, kh, kw): channel k only sees channel k.
        weight = torch.eye(3, **opts)[:, :, None, None] * kernel
    # The weight is built directly on X's device instead of inside a fresh
    # nn.Conv2d moved to a hard-coded "cuda:4". That module also kept a
    # trainable weight, because `module.requires_grad = False` freezes nothing.
    return nn.functional.conv2d(X, weight)

In [25]:
def shift_diamond_depthwise(X, a, b, c=0, d=0, dim=0, diag=False):
    """
    Apply a fixed stencil to a batch of 3-channel images using a depthwise conv.

    The spatial stencils use a grouped convolution (groups=3, one filter per
    channel; cross-correlation, stride 1, no padding). The channel-mixing case
    uses a dense 1x1 convolution.

    Args:
        X: tensor of shape (batch, 3, im_height, im_width).
        a, b, c, d: stencil coefficients. Terms are *added*, so
            ``a*X[i] - b*X[i+1]`` is obtained by passing a negative ``b``.
        dim: direction used when ``diag`` is False:
            0 -> horizontal: a*X[..., j] + b*X[..., j+1] + c*X[..., j+2];
                 output (batch, 3, H, W-2)
            1 -> vertical: the same stencil along the height axis;
                 output (batch, 3, H-2, W)
            2 -> between channels with a 1x1 kernel (c, d unused):
                 out0 = a*x0 + b*x1, out1 = a*x0 + b*x2, out2 = a*x1 + b*x2;
                 output (batch, 3, H, W)
        diag: if True, ignore ``dim`` and apply the 2x2 kernel
            [[a, b], [c, d]] to each channel; output (batch, 3, H-1, W-1).

    Returns:
        The filtered tensor, on X's device and with X's dtype. The filter is a
        constant, so gradients flow to X only.

    Raises:
        ValueError: if ``diag`` is False and ``dim`` is not 0, 1 or 2.
    """
    opts = {"dtype": X.dtype, "device": X.device}
    if not diag and dim == 2:
        # Channel mixing cannot be expressed per channel, so use a dense 1x1 conv.
        weight = torch.tensor([[a, b, 0], [a, 0, b], [0, a, b]], **opts).view(3, 3, 1, 1)
        return nn.functional.conv2d(X, weight)

    if diag:
        stencil = [[a, b], [c, d]]
    elif dim == 0:
        stencil = [[a, b, c]]
    elif dim == 1:
        stencil = [[a], [b], [c]]
    else:
        raise ValueError(f"dim must be 0, 1 or 2 when diag is False, got {dim!r}")
    kernel = torch.tensor(stencil, **opts)
    # Weight shape (3, 1, kh, kw): with groups=3 each channel gets its own copy of the stencil.
    weight = kernel.expand(3, 1, *kernel.shape).contiguous()
    # The weight is built on X's device; there is no hard-coded "cuda:4"
    # nn.Conv2d (whose weight stayed trainable despite `requires_grad = False`).
    return nn.functional.conv2d(X, weight, groups=3)

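## Test build_L operation

Check that the depthwise (grouped) convolution in `shift_diamond_depthwise` gives the same result as the dense block-diagonal convolution in `shift_diamond`.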

In [19]:
# One smoke-test batch of shape (batch, channels, h, w).
batch = 16
channels = 3
h, w = 64, 64
# Fall back to CPU when CUDA is unavailable (a bare "cuda:4" crashed there).
# NOTE: GPU index 4 is machine-specific; adjust it for your setup.
device = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)  # reproducible input across Restart & Run All
X = torch.randn(batch, channels, h, w, device=device)

# Stencil coefficients: with diag=True, a 2x2 kernel [[a, b], [0, 0]].
a, b = 1, -2
dim = 0
diag = True

# The same stencil via the depthwise and the dense implementation.
Y = shift_diamond_depthwise(X, a, b, c=0, d=0, dim=dim, diag=diag)
Z = shift_diamond(X, a, b, c=0, d=0, dim=dim, diag=diag)

In [18]:
Y.size()

torch.Size([16, 3, 63, 63])

In [20]:
# Both implementations must produce outputs of identical shape.
assert Z.shape == Y.shape, "dense and depthwise outputs differ in shape"
Z.shape

torch.Size([16, 3, 63, 63])

In [27]:
def test_depthwise(X, a, b, c, d, dim=0, diag=True):
    """
    Squared-error discrepancy between the depthwise and dense stencil implementations.

    Runs shift_diamond_depthwise and shift_diamond on the same input with the
    same coefficients and returns sum((Y - Z) ** 2) as a 0-dim tensor.
    A value of 0 means the two agree exactly.

    The defaults are literals (dim=0, diag=True). Previously they were bound to
    whatever global `dim` / `diag` held when the cell ran.
    """
    with torch.no_grad():
        # Forward c and d. They were previously hard-coded to 0, so those
        # stencil taps were never compared.
        Y = shift_diamond_depthwise(X, a, b, c=c, d=d, dim=dim, diag=diag)
        Z = shift_diamond(X, a, b, c=c, d=d, dim=dim, diag=diag)
    return torch.sum((Y - Z) ** 2)

In [31]:
# Sweep batch sizes, stencil directions and coefficients. Every configuration
# should give the same output from the depthwise and the dense implementation.
batches = [16, 32]
dims = [0, 1, 2]
diags = [True, False]
a_s = [1, 0, -1]
bs = [1, 0, -3]

# Defined here so the cell does not depend on globals from earlier cells.
channels, h, w = 3, 64, 64
# Fall back to CPU when CUDA is unavailable; GPU index 4 is machine-specific.
device = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")

total_error = 0
# product() iterates in the same order as the original nested loops (batch outermost).
for batch, dim, diag, a, b in itertools.product(batches, dims, diags, a_s, bs):
    X = torch.randn(batch, channels, h, w, device=device)
    total_error += test_depthwise(X, a, b, b, 0, dim=dim, diag=diag)

print(total_error)

tensor(0., device='cuda:4')


## Looking at some generated images

In [20]:
# One training run's log directory and the pickle holding its logged
# metrics and generated image grids.
run_dir = os.path.join('logs', '2020-04-22--10-05-40_log')
file_path = os.path.join(run_dir, 'WGAN_data2020-04-22--10-05-40trial.p')

# These values were previously bare mid-cell expressions and so were never displayed.
print('runs in logs/:', sorted(os.listdir('logs')))
print('files in run dir:', sorted(os.listdir(run_dir)))

# NOTE: pickle.load can execute arbitrary code; only load log files you produced yourself.
with open(file_path, 'rb') as f:
    A = pickle.load(f)

print('keys:', list(A.keys()))
print('FID scores:', A['FID_scores'])

images = A['img_list']
for image in images:
    fig, ax = plt.subplots(figsize=(16, 16))
    # assumes each entry is a channels-first (C, H, W) array, as the transpose implies — TODO confirm
    ax.imshow(np.transpose(image, (1, 2, 0)))
    ax.axis('off')
    plt.show()
    plt.close(fig)  # free figure memory when many grids are logged

dict_keys(['img_list', 'G_losses', 'D_losses', 'GP_losses', 'FID_scores', 'grad_list', 'wgrad_list', 'grad_sum_list', 'epoch_time'])