In [110]:
import torch.nn as nn
import torch.nn.functional as F

def max_singular_value(weight, u, Ip):
    assert(Ip >= 1)
    
    _u = u
    for _ in range(Ip):
        _v = F.normalize(torch.mm(_u, weight), p=2, dim=1)
        _u = F.normalize(torch.mm(_v, weight.transpose(0,1)), p=2, dim=1)
    sigma = torch.sum(F.linear(_u, weight.transpose(0,1)) * _v)
    return sigma, _u

class SNLinear(nn.Linear):
    
    def __init__(self, in_features, out_features):
        super(SNLinear, self).__init__(
            in_features, out_features
        )
        self.Ip = 1
        self.u = torch.randn(1, out_features)
        
    @property
    def W_bar(self):
        sigma, u = max_singular_value(self.weight.data, self.u, self.Ip)
        self.u = u
        return self.weight / sigma

    def forward(self, x):
        return F.linear(x, self.W_bar, self.bias)
    
class SNConv2d(nn.Conv2d):
    
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super(SNConv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias
        )
        self.Ip = 1
        self.u = torch.randn(1, out_channels)
        
    @property
    def W_bar(self):
        w = self.weight.data
        w = w.view(w.shape[0], -1)
        
        sigma, u = max_singular_value(w, self.u, self.Ip)
        self.u = u
        return self.weight / sigma
    
    def forward(self, x):
        return F.conv2d(
            x, 
            self.W_bar,
            bias=self.bias,
            stride=self.stride, 
            padding=self.padding, 
            dilation=self.dilation,
            groups=self.groups
        ) 

In [102]:
lin = SNConv2d(100, 120, 3)



w.shape

torch.Size([120, 900])

In [111]:
lin = SNConv2d(14, 15, 3)

x = torch.randn(12, 14, 10, 10)
lin(x).shape

torch.Size([12, 15, 8, 8])

In [92]:
for param in SNLinear(10, 15).parameters():
    print(param)

Parameter containing:
tensor([[-0.1344, -0.0797, -0.2561,  0.1094, -0.0461,  0.2431, -0.1461, -0.3131,
          0.2215, -0.1711],
        [ 0.0188,  0.2529, -0.2751,  0.3099,  0.0099, -0.0387, -0.3068,  0.1377,
          0.2275, -0.1156],
        [ 0.0083,  0.0058, -0.0950,  0.0093, -0.2357, -0.0082,  0.0384, -0.0678,
         -0.0844, -0.2534],
        [-0.1961,  0.0689,  0.1861,  0.2674,  0.0690,  0.0026, -0.1328, -0.0788,
         -0.0578, -0.0817],
        [-0.1379,  0.3123,  0.0666,  0.2435, -0.2865,  0.2546, -0.1945,  0.1468,
         -0.0049,  0.0315],
        [-0.2007, -0.3100,  0.3121,  0.1752, -0.0687,  0.1491,  0.0728,  0.0833,
          0.1366,  0.3069],
        [-0.1039,  0.1927, -0.0085,  0.2404,  0.3003,  0.3099,  0.1758, -0.0459,
          0.1812,  0.0498],
        [-0.2133,  0.2265, -0.0884,  0.0437,  0.1080, -0.2533, -0.1086, -0.0751,
         -0.0074,  0.2466],
        [-0.1383, -0.0812, -0.0492,  0.2006,  0.2030, -0.0799, -0.0996,  0.2556,
         -0.1867, -0.1961

array([7.5015874, 5.1000624, 4.716616 , 4.321139 , 3.789548 , 3.5263743,
       3.2069032, 1.9736412, 1.4177839, 0.9812959], dtype=float32)

In [29]:
def powermethod_iter(W, u, Ip):
    assert(W.shape[0] = u.shape[0])
    v_prime = torch.mm(u, W)
    v_prime = v_prime / torch.norm(v_prime)
    
    u_prime = torch.mm(W, v)
    
    sigma = u_prime.mm(W).mm(v_prime)
    
    return sigma, u_prime, v_prime

W = torch.randn(10, 15)

u = torch.randn(1, 15)
powermethod_iter(W, u, 1)

RuntimeError: size mismatch, m1: [1 x 15], m2: [10 x 15] at /Users/soumith/b101/2019_02_04/wheel_build_dirs/wheel_3.6/pytorch/aten/src/TH/generic/THTensorMath.cpp:940

In [117]:
import torchvision
import torchvision.transforms as transforms

########################################################################
# The output of torchvision datasets are PILImage images of range [0, 1].
# We transform them to Tensors of normalized range [-1, 1].

def custom_preprocess(tensor):
    transformed_tensor = 2. * tensor - 1.
    transformed_tensor += torch.rand(*transformed_tensor.shape) / 128.
    return transformed_tensor
    
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        custom_preprocess
    ]
)

trainset = torchvision.datasets.CIFAR10(root='Users/kylesargent/Desktop/', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=2)



Files already downloaded and verified


In [45]:
import logging
import sys
logging.basicConfig(filename=os.path.join('/Users/kylesargent/', 'training.log'), level=logging.DEBUG)

handler = logging.StreamHandler(sys.stdout)
handler.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
root.addHandler(handler)

NameError: name 'root' is not defined

In [15]:
batch_size = 16
epochs = 10
dis_iters = 1

from cifar10_models import Generator, Discriminator

G = Generator()
D = Discriminator()

g_optim = torch.optim.Adam(G.parameters(), lr=.001)
d_optim = torch.optim.Adam(D.parameters(), lr=.001)

for epoch in range(epochs):
    for i in range(train_data.shape[0] // batch_size):
        z = Variable(sample_z(batch_size))
        x_real = Variable(train_data[batch_size * i : batch_size * (i+1)])
        
        for k in range(dis_iters):
            # train discriminator
            x_fake = G(z)
            dis_fake = D(x_fake.detach())
            dis_real = D(x_real)

            loss = dis_loss(dis_fake, dis_real)
            loss.backward()
            d_optim.step()

            # train generator
            if k==0:
                g_optim.zero_grad()
                dis_fake = D(x_fake)

                loss = gen_loss(dis_fake) 
                loss.backward()
                g_optim.step()
        

ImportError: cannot import name 'Generator'

In [21]:
from cifar10_models import Cifar10Generator
from torch.distributions.normal import Normal
from torch.autograd import Variable

def sample_z(batch_size):
    n = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
    return n.sample((batch_size, 128)).squeeze(2)

G = Cifar10Generator()
z = (sample_z(batch_size))

G(z)

tensor([[[[ 1.3192e-01,  1.2600e-01,  3.0534e-01,  ...,  4.8395e-01,
            4.2447e-01,  3.3980e-01],
          [-9.0867e-02, -3.8691e-01, -9.0201e-02,  ...,  2.7825e-01,
           -4.6404e-02,  1.0383e-01],
          [ 5.2356e-02, -3.8307e-01, -1.0839e-01,  ...,  5.4389e-01,
            2.9140e-01,  3.9991e-01],
          ...,
          [ 5.2455e-01,  6.5654e-01,  6.7763e-01,  ...,  2.7052e-01,
            3.3309e-01,  3.5444e-01],
          [ 3.9760e-01,  4.5689e-01,  4.7463e-01,  ...,  2.4466e-01,
            2.9873e-01,  2.6779e-01],
          [ 1.4250e-01,  1.2857e-01,  2.6162e-01,  ...,  1.4585e-01,
            2.4795e-01, -4.9396e-02]],

         [[-3.5721e-01, -1.5552e-01, -1.9746e-01,  ..., -2.9610e-01,
           -2.7172e-01,  6.7485e-02],
          [-4.7579e-02, -8.0179e-02, -1.6852e-01,  ..., -2.1015e-01,
           -1.8315e-01, -1.6189e-02],
          [ 2.8562e-01,  2.3068e-01,  1.7587e-01,  ...,  9.8998e-02,
           -1.4995e-01, -3.6112e-02],
          ...,
     