In [51]:
import torch
from torchvision import datasets
import torchvision.transforms as transforms
from torch.autograd import Variable, grad
from torch.autograd.functional import hessian
import numpy as np
import sys
if '/home/cybai/MRGAN/' not in sys.path:
    sys.path.insert(0, '/home/cybai/MRGAN/')
from wgan_gp import Generator, Discriminator

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
latent_dim = 128
n_channels = 3
bsize = 128

In [4]:
generator = Generator(latent_dim=latent_dim, n_channels=n_channels).to(device=device)
discriminator = Discriminator(latent_dim=latent_dim, n_channels=n_channels).to(device=device)

In [5]:
load_gen_path = '/home/cybai/MRGAN/images/202010250916/generator_epoch2400.pth'
load_dis_path = '/home/cybai/MRGAN/images/202010250916/discriminator_epoch2400.pth'

In [59]:
generator.model

Sequential(
  (0): Linear(in_features=128, out_features=8192, bias=True)
  (1): BatchNorm1d(8192, eps=1e-05, momentum=0.8, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): View()
  (4): ConvTranspose2d(512, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
  (5): BatchNorm2d(256, eps=1e-05, momentum=0.8, affine=True, track_running_stats=True)
  (6): ReLU(inplace=True)
  (7): ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
  (8): BatchNorm2d(128, eps=1e-05, momentum=0.8, affine=True, track_running_stats=True)
  (9): ReLU(inplace=True)
  (10): ConvTranspose2d(128, 3, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
  (11): Tanh()
)

In [71]:
def get_all_params_generator(g):
    all_params = []
    for layer in g.model:
        try:
            if layer.weight.requires_grad:
                all_params.append(layer.weight)
        except:
            print(f"no weight in {layer}")
    return all_params

In [72]:
generator.load_state_dict(torch.load(load_gen_path))
discriminator.load_state_dict(torch.load(load_dis_path))

<All keys matched successfully>

In [73]:
all_params_g = get_all_params_generator(generator)

no weight in ReLU(inplace=True)
no weight in View()
no weight in ReLU(inplace=True)
no weight in ReLU(inplace=True)
no weight in Tanh()


In [74]:
all_params_g

[Parameter containing:
 tensor([[-0.0171,  0.0052,  0.0154,  ...,  0.0961,  0.0477,  0.0512],
         [-0.0334,  0.0604, -0.0841,  ..., -0.0006, -0.0458, -0.0530],
         [ 0.1180,  0.0553, -0.0032,  ...,  0.0214, -0.0004,  0.0981],
         ...,
         [ 0.1608,  0.0268, -0.0499,  ...,  0.0679, -0.0082,  0.0983],
         [-0.0181,  0.0216, -0.1130,  ...,  0.0251, -0.0307,  0.0127],
         [ 0.0758,  0.0538, -0.0034,  ...,  0.0782, -0.0335,  0.0954]],
        device='cuda:0', requires_grad=True),
 Parameter containing:
 tensor([1.0103, 0.8993, 0.9315,  ..., 1.0319, 0.8533, 1.0921], device='cuda:0',
        requires_grad=True),
 Parameter containing:
 tensor([[[[-4.7325e-02, -8.6777e-04, -1.0161e-02, -2.8912e-03,  2.4838e-02],
           [-4.7722e-02, -6.0837e-02, -2.3332e-02,  1.6997e-02,  5.0613e-02],
           [-2.9138e-02, -2.4450e-02,  3.0840e-02,  5.2124e-02,  1.1825e-03],
           [ 2.8588e-02, -1.3355e-02,  3.0386e-02,  6.7025e-02,  4.8705e-02],
           [-5.0003e-0

In [7]:
data_dir = '/home/cybai/MRGAN/data/cifar10'

In [8]:
dataloader = torch.utils.data.DataLoader(
    datasets.CIFAR10(
        data_dir,
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize([0.5] * n_channels, [0.5] * n_channels)]
        ),
    ),
    batch_size=bsize,
    shuffle=True,
    drop_last=True, # to avoid sample images not full (if last batch less than 36)
)

Files already downloaded and verified


In [82]:
z = Variable(torch.from_numpy(np.random.normal(0, 1, (bsize, latent_dim))).float()).to(device)
fake_imgs = generator(z)
g_loss = -torch.mean(discriminator(fake_imgs))

In [93]:
foo = grad(g_loss, all_params_g, create_graph=True)[0]
bar = Variable(torch.from_numpy(np.random.normal(0, 1, (latent_dim, 1))).float()).to(device)
print(foo, bar)
grad(torch.matmul(foo, bar).to(device), all_params_g, torch.ones(foo.shape[0], 1))

tensor([[ 0.0013, -0.0014,  0.0010,  ..., -0.0011,  0.0009, -0.0010],
        [-0.0018,  0.0014,  0.0014,  ...,  0.0011,  0.0002, -0.0005],
        [ 0.0014, -0.0022,  0.0021,  ..., -0.0020,  0.0003, -0.0005],
        ...,
        [-0.0013, -0.0012, -0.0062,  ...,  0.0038,  0.0015,  0.0013],
        [ 0.0040, -0.0017, -0.0019,  ...,  0.0018, -0.0026, -0.0006],
        [ 0.0022,  0.0005, -0.0015,  ..., -0.0027,  0.0056, -0.0002]],
       device='cuda:0', grad_fn=<TBackward>) tensor([[ 1.9229e+00],
        [-3.1958e-01],
        [ 5.3976e-01],
        [ 5.5201e-01],
        [-1.0413e-01],
        [-4.4630e-01],
        [ 1.7233e-01],
        [-1.2153e+00],
        [ 2.0877e-02],
        [-4.4553e-01],
        [ 1.1646e+00],
        [-1.3763e+00],
        [ 2.1485e+00],
        [ 5.0445e-01],
        [-1.2736e+00],
        [-3.1038e-01],
        [ 6.4690e-02],
        [ 6.6110e-01],
        [ 1.3639e+00],
        [ 7.2581e-01],
        [ 6.1449e-01],
        [ 3.6899e-01],
        [-5.296

RuntimeError: invalid gradient at index 0 - expected type TensorOptions(dtype=float, device=cuda:0, layout=Strided, requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt)) but got TensorOptions(dtype=float, device=cpu, layout=Strided, requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))

In [55]:
def loss_fn_wrt_theta(trainable_params):

    return g_loss

In [57]:
param_hessian = hessian(loss_fn_wrt_theta, generator)

TypeError: The inputs given to hessian must be either a Tensor or a tuple of Tensors but the given inputs has type <class 'wgan_gp.Generator'>.

In [54]:
layer = generator.model[0]
param_grad = grad(g_loss, layer.weight, retain_graph=True)

RuntimeError: The Tensor returned by the function given to hessian should contain a single element

In [47]:
param_grads = []
for layer in generator.model:
    try:
        param_grad = grad(g_loss, layer.weight, retain_graph=True)
        param_ggrad = grad(param_grad, layer.weight, retain_graph=True)
        param_grads.append(param_grad)
    except:
        print(f"{layer} has no weight")

Linear(in_features=128, out_features=8192, bias=True) has no weight
BatchNorm1d(8192, eps=1e-05, momentum=0.8, affine=True, track_running_stats=True) has no weight
ReLU(inplace=True) has no weight
View() has no weight
ConvTranspose2d(512, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1)) has no weight
BatchNorm2d(256, eps=1e-05, momentum=0.8, affine=True, track_running_stats=True) has no weight
ReLU(inplace=True) has no weight
ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1)) has no weight
BatchNorm2d(128, eps=1e-05, momentum=0.8, affine=True, track_running_stats=True) has no weight
ReLU(inplace=True) has no weight
ConvTranspose2d(128, 3, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1)) has no weight
Tanh() has no weight


In [46]:
param_grads[0][0].shape

torch.Size([8192, 128])