In [20]:
import torch
import os
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from wgan_gp import *
from utils import *
import opacus

# Parameters

*   image_shape: the number of pixels in each image. We further downsize every image to 64 x 64 to be compatible with the self-attention GAN architecture
*   n_classes: the number of classes (10, since there are 10 different emojis)
*   image_path: path to image folder
*   n_epochs: the number of times to iterate through the entire dataset when training
*   z_dim: the dimension of the noise vector
*   display_step: how often to display/visualize the images
*   batch_size: the number of images per forward/backward pass
*   lr: the learning rate
*   beta_1, beta_2: coefficients used for computing running averages of gradient and its square (used for Adam optimizer)
*   lambda_gp: weight of the gradient penalty
*   device: the device type
*   g_conv_dim: the number of dimensions in the last convolutional layer of the generator
*   d_conv_dim: the number of dimensions in the first convolutional layer of the discriminator
*   ckpt_epoch: how often to save model checkpoint

In [12]:
image_shape = (3, 64, 64)
n_classes = 10
image_path = 'generated/csawgan_gp'
n_epochs = 100
z_dim = 64
display_step = 500
batch_size = 128
lr = 0.0002
beta_1 = 0.5
beta_2 = 0.999
lambda_gp = 10
g_conv_dim = 64
d_conv_dim = 64
ckpt_epoch = 2
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [13]:
transform = transforms.Compose([
    transforms.Resize((image_shape[1], image_shape[2])),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

dataset = ImageFolder(image_path, transform=transform)

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True)

# Initialize generator, discriminator, optimizers

In [18]:
generator_input_dim, discriminator_im_chan = get_input_dimensions(z_dim, image_shape, n_classes)

gen = Generator(image_shape[1], generator_input_dim, g_conv_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
disc = Discriminator(discriminator_im_chan, d_conv_dim).to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=(beta_1, beta_2))

gen = gen.apply(weights_init)
disc = disc.apply(weights_init)

In [19]:
print(disc)

Discriminator(
  (l1): Sequential(
    (0): SpectralNorm(
      (module): Conv2d(13, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    )
    (1): LeakyReLU(negative_slope=0.2)
  )
  (l2): Sequential(
    (0): SpectralNorm(
      (module): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    )
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (l3): Sequential(
    (0): SpectralNorm(
      (module): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    )
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (l4): Sequential(
    (0): SpectralNorm(
      (module): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    )
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (last): Sequ

# Training

In [47]:
def train(start_epoch, n_epochs, dataloader, gen, gen_opt, disc, disc_opt, save_path):
    cur_step = 0
    generator_losses = []
    discriminator_losses = []

    noise_and_labels = False
    fake = False

    fake_image_and_labels = False
    real_image_and_labels = False
    disc_fake_pred = False
    disc_real_pred = False

    for epoch in range(start_epoch, start_epoch + n_epochs):
        # Dataloader returns the batches and the labels
        for real, labels in tqdm(dataloader):
            print(cur_step)
            for param in disc.parameters(): 
                param.accumulated_grads = []
            cur_batch_size = len(real)
            # Flatten the batch of real images from the dataset
            real = real.to(device)

            one_hot_labels = get_one_hot_labels(labels.to(device), n_classes)
            image_one_hot_labels = one_hot_labels[:, :, None, None]
            image_one_hot_labels = image_one_hot_labels.repeat(1, 1, image_shape[1], image_shape[2])

            # Zero out the discriminator gradients
            disc_opt.zero_grad()
            # Get noise corresponding to the current batch_size 
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)

            noise_and_labels = combine_vectors(fake_noise, one_hot_labels)
            
            for i in range(cur_batch_size): 
                fake = gen(torch.unsqueeze(noise_and_labels[i], 0))

                fake_image_and_labels = combine_vectors(fake.detach(), torch.unsqueeze(image_one_hot_labels[i], 0))
                real_image_and_labels = combine_vectors(torch.unsqueeze(real[i], 0), torch.unsqueeze(image_one_hot_labels[i], 0))
                disc_fake_pred = disc(fake_image_and_labels)
                disc_real_pred = disc(real_image_and_labels)

                gp = compute_gradient_penalty(disc, torch.unsqueeze(real[i], 0), fake.detach(), torch.unsqueeze(image_one_hot_labels[i], 0), device=device)
                disc_loss = torch.mean(disc_fake_pred) - torch.mean(disc_real_pred) + gp * lambda_gp

                disc_loss.backward(retain_graph=True)
                # disc_opt.step() 

                # Keep track of the average discriminator loss
                discriminator_losses += [disc_loss.item()]

                ### Update generator ###
#                 # Zero out the generator gradients
#                 gen_opt.zero_grad()

                fake_image_and_labels = combine_vectors(fake, torch.unsqueeze(image_one_hot_labels[i], 0))
                disc_fake_pred = disc(fake_image_and_labels)
                gen_loss = -torch.mean(disc_fake_pred)
                gen_loss.backward()
                gen_opt.step()
            
                for param in disc.parameters(): 
                    print(param)
                    print(param.grad)
                    if param.grad is None: 
                        continue
                    per_sample_grad = param.grad.detach().clone()
                    clip_grad_norm_(per_sample_grad, max_norm=1)
                    param.accumulated_grads.append(per_sample_grad)

                # Keep track of the generator losses
                generator_losses += [gen_loss.item()]
            
            # Aggregate back
            for param in disc.parameters(): 
                param.grad = torch.stack(param.accumulated_grads, dim=0)
                
            for param in disc.parameters(): 
                param = param - lr * param.grad
                param += torch.normal(mean=0, std=1.5)
                param.grad = 0

            if cur_step % display_step == 0 and cur_step > 0:
                gen_mean = sum(generator_losses[-display_step:]) / display_step
                disc_mean = sum(discriminator_losses[-display_step:]) / display_step
                print(f"Step {cur_step}: Generator loss: {gen_mean}, discriminator loss: {disc_mean}")
                show_tensor_images(fake)
                show_tensor_images(real)
                step_bins = 20
                x_axis = sorted([i * step_bins for i in range(len(generator_losses) // step_bins)] * step_bins)
                num_examples = (len(generator_losses) // step_bins) * step_bins
                plt.plot(
                    range(num_examples // step_bins), 
                    torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
                    label="Generator Loss"
                )
                plt.plot(
                    range(num_examples // step_bins), 
                    torch.Tensor(discriminator_losses[:num_examples]).view(-1, step_bins).mean(1),
                    label="Discriminator Loss"
                )
                plt.legend()
                plt.show()
            elif cur_step == 0:
                print("Start training...")
            cur_step += 1
            
        if epoch % ckpt_epoch == 0:
            ckpt_dict = {
                'epoch': epoch,
                'gen_state_dict': gen.state_dict(),
                'gen_opt': gen_opt.state_dict(),
                'disc_state_dict': disc.state_dict(),
                'disc_opt': disc_opt.state_dict(),
            }
            checkpoint_path = os.path.join(save_path, f'epoch_{epoch}.pth.tar')
            torch.save(ckpt_dict, checkpoint_path)

In [48]:
train(0, n_epochs, dataloader, gen, gen_opt, disc, disc_opt, 'model_ckpt/cSAWGAN-GP-DP')

  0%|          | 0/52 [00:00<?, ?it/s]

0
Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-2.0022e+00, -3.9654e-01, -6.1716e-01, -5.1999e-01,  3.1981e+00,
        -2.2701e+00, -1.9017e+00, -1.7336e-01, -1.9623e-01, -2.0170e-01,
         2.1500e+00, -9.3310e-02,  4.6199e-01,  1.5469e-01, -9.7702e-02,
         3.2199e-02, -3.2110e-02, -1.0994e+00, -4.6526e-03, -3.5807e-02,
         7.1667e

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-3.6980, -0.0419, -1.1190, -0.6686,  4.5780, -3.5377, -3.6029, -0.4951,
        -0.4857, -0.4163,  4.1685, -0.2827,  1.3226,  2.0106,  1.9397, -0.0484,
         0.2368, -1.7990, -0.1335, -0.2298,  1.2040, -0.3955, -5.3568, -0.3308,
        -1.0856, -0.8148, -1.0609, -1.4567,  1.4317,  4.7497,  1.6416,

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-5.6349,  0.3354, -1.6588, -0.7212,  6.8671, -4.8015, -5.4811, -0.9162,
        -0.8253, -0.6849,  6.7161, -0.4903,  2.0053,  3.3688,  2.4124, -0.1180,
         0.5512, -3.0489,  0.6001, -0.4225,  2.1087, -0.2423, -4.6173, -0.6658,
        -1.4908, -1.3146, -1.5775, -1.9284,  3.2844,  7.0225,  2.6240,

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-7.5280e+00,  3.1188e-01, -2.1262e+00, -1.0728e+00,  8.1743e+00,
        -5.9031e+00, -6.5848e+00, -1.2725e+00, -1.1386e+00, -8.8735e-01,
         8.5704e+00, -7.3480e-01,  2.8592e+00,  4.5801e+00, -1.0970e+00,
        -1.2520e-01,  3.5965e-01, -4.3800e+00,  4.0833e-01, -6.8302e-01,
         2.7918e+0

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([ -9.8768,  -0.0651,  -2.6982,  -1.8726,  10.5092,  -8.3343,  -8.6403,
         -1.4206,  -1.2275,  -1.0268,  10.6027,  -0.8069,   3.5317,   5.2153,
         -1.6994,  -0.1598,   0.5991,  -5.6757,  -0.2179,  -0.7286,   2.6047,
          1.5968,  -3.2071,  -0.9152,  -3.5275,  -2.0390,  -2.6555,  -4.1075

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-12.0891,   0.0837,  -3.3980,  -2.2966,  10.1841, -11.1196, -11.5773,
         -1.5486,  -1.3516,  -1.3207,  13.1585,  -0.8211,   4.5814,   9.0427,
         -0.9143,  -0.3155,   1.5890,  -5.6362,   0.0812,  -0.5938,   2.6793,
          1.1747,  -5.7086,  -0.9616,  -4.0317,  -2.4003,  -3.3747,  -4.6284

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-13.2149,   0.2653,  -3.7704,  -2.2849,  11.4610, -11.6258, -11.7542,
         -1.8130,  -1.6896,  -1.4805,  14.9247,  -1.0103,   4.9825,  10.0562,
         -3.8314,  -0.2859,   1.4761,  -6.6129,   0.1968,  -0.8474,   3.7434,
         -3.3385,  -4.5977,  -1.1796,  -4.1992,  -2.7897,  -3.7018,  -5.1294

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-15.1887,   0.0913,  -4.2731,  -2.9156,  17.4641, -13.9384, -13.9992,
         -1.9451,  -1.7642,  -1.5839,  16.9837,  -1.0586,   5.5605,  10.1163,
         -1.8788,  -0.3152,   1.3605,  -7.6376,  -0.1224,  -0.8999,   3.6889,
         -3.2470,  -2.1163,  -1.1917,  -5.0986,  -3.1613,  -4.2543,  -6.4015

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-1.7129e+01,  6.7299e-01, -4.6623e+00, -3.0314e+00,  1.6691e+01,
        -1.4376e+01, -1.4741e+01, -2.4232e+00, -2.2005e+00, -1.7896e+00,
         1.8672e+01, -1.3919e+00,  6.2329e+00,  1.1482e+01, -2.1565e+00,
        -3.4427e-01,  9.6694e-01, -9.4194e+00, -6.8321e-01, -1.3151e+00,
         4.5075e+0

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-18.4341,   0.6178,  -5.0361,  -2.6362,  16.6440, -14.9657, -13.9536,
         -2.8631,  -2.6432,  -1.8997,  20.5635,  -1.6879,   6.2378,   9.2435,
         -6.5136,  -0.1649,  -0.0894, -11.3287,  -0.8259,  -1.7009,   6.2150,
         -9.3804,  -1.7333,  -1.9282,  -5.5396,  -4.1956,  -5.1054,  -8.4400

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-2.0895e+01,  7.2081e-02, -5.8164e+00, -3.6748e+00,  1.9546e+01,
        -1.8472e+01, -1.7022e+01, -2.9028e+00, -2.6727e+00, -2.1449e+00,
         2.3068e+01, -1.6089e+00,  7.0430e+00,  9.3126e+00, -3.4988e+00,
        -1.8690e-01,  3.7380e-01, -1.2028e+01, -1.1909e+00, -1.5102e+00,
         6.1682e+0

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-2.2582e+01, -1.4043e-01, -6.1192e+00, -3.8745e+00,  1.9781e+01,
        -1.9530e+01, -1.7744e+01, -3.1956e+00, -2.9625e+00, -2.1866e+00,
         2.4525e+01, -1.7948e+00,  6.9825e+00,  1.1712e+01,  4.7620e-01,
        -1.0703e-01, -6.1134e-01, -1.3824e+01, -2.0163e+00, -1.8108e+00,
         6.7462e+0

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-2.4786e+01, -6.4992e-02, -6.5760e+00, -4.5915e+00,  2.1172e+01,
        -2.1062e+01, -1.9573e+01, -3.4351e+00, -3.1491e+00, -2.3415e+00,
         2.6367e+01, -1.9239e+00,  7.7191e+00,  1.2638e+01,  5.5112e-01,
        -1.2313e-01, -6.5295e-01, -1.5124e+01, -2.4267e+00, -1.9479e+00,
         7.2402e+0

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-2.7018e+01,  1.6392e-02, -7.0317e+00, -4.9203e+00,  2.4409e+01,
        -2.2361e+01, -2.0961e+01, -3.7009e+00, -3.3716e+00, -2.5196e+00,
         2.8259e+01, -2.1387e+00,  8.7907e+00,  1.2861e+01, -8.3015e-01,
        -1.7535e-01, -4.9361e-01, -1.6605e+01, -2.9534e+00, -2.0689e+00,
         7.6434e+0

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-27.7867,  -1.2154,  -7.2810,  -4.7905,  25.5104, -22.7888, -19.6867,
         -3.9312,  -3.5688,  -2.5724,  29.6068,  -2.1972,   8.6309,  14.0080,
         -9.1565,   0.0915,  -0.8779, -17.9761,  -2.9365,  -2.3270,   8.3309,
         -7.8885,  -2.2438,  -2.6683,  -9.1712,   6.5581,  -7.2637, -14.3597

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-2.9704e+01, -6.1630e-01, -7.7421e+00, -4.9015e+00,  2.0736e+01,
        -2.3847e+01, -2.1437e+01, -4.3071e+00, -3.8943e+00, -2.7389e+00,
         3.1608e+01, -2.4792e+00,  9.5558e+00,  1.4234e+01, -7.2145e+00,
        -5.4801e-03, -7.7824e-01, -1.8975e+01, -2.9470e+00, -2.5800e+00,
         9.1354e+0

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-31.3865,  -1.2388,  -7.9623,  -5.1881,  21.6580, -24.5923, -21.3048,
         -4.6123,  -4.1373,  -2.7619,  32.7512,  -2.6418,   9.5102,  16.1113,
         -7.9328,   0.1767,  -1.8660, -20.9922,  -3.4537,  -2.8788,   9.6975,
         -7.4877,  -0.2819,  -3.0921, -10.1687,   5.6374,  -8.1914, -16.8793

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-33.7159,  -1.8258,  -8.5384,  -5.7092,  30.4897, -27.2915, -23.6200,
         -4.7140,  -4.2574,  -2.8627,  34.7014,  -2.6306,   9.5056,  22.4404,
         -0.8247,   0.1653,  -2.1595, -22.2024,  -3.9875,  -2.8767,   9.7068,
         -6.5082,  -0.8871,  -3.1354, -11.1603,   5.2348,  -8.8120, -18.0664

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-3.5638e+01, -1.9700e+00, -8.9867e+00, -5.9269e+00,  3.1750e+01,
        -2.8702e+01, -2.5181e+01, -4.9944e+00, -4.4788e+00, -3.0123e+00,
         3.6435e+01, -2.7704e+00,  1.0041e+01,  2.3531e+01,  1.7808e+00,
         1.7701e-01, -2.1430e+00, -2.3020e+01, -4.4214e+00, -3.0172e+00,
         9.9238e+0

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-37.7061,  -2.0729,  -9.6085,  -6.7849,  36.6203, -30.9594, -27.2923,
         -5.2502,  -4.6733,  -3.1960,  38.6629,  -2.9087,  10.7606,  23.7099,
          1.6277,   0.1538,  -1.9975, -24.3579,  -5.0361,  -3.1143,  10.4435,
         -6.7152,   4.2086,  -3.3796, -13.0238,   4.4548,  -9.8674, -20.0519

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-39.1156,  -1.5693,  -9.9214,  -6.9000,  36.7048, -31.6560, -28.3369,
         -5.5325,  -4.9298,  -3.3640,  40.3199,  -3.0637,  11.5044,  24.6208,
          4.8609,   0.1026,  -1.7833, -25.0554,  -5.0938,  -3.3059,  10.8185,
        -10.5557,   3.8542,  -3.5623, -13.3840,   4.0804, -10.2490, -20.4281

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-4.1295e+01, -1.8138e+00, -1.0422e+01, -7.5054e+00,  3.8668e+01,
        -3.3500e+01, -3.0064e+01, -5.7237e+00, -5.0738e+00, -3.5210e+00,
         4.2432e+01, -3.1536e+00,  1.2163e+01,  2.5522e+01,  5.0167e+00,
         7.2927e-02, -1.6179e+00, -2.6324e+01, -5.5734e+00, -3.3800e+00,
         1.1102e+0

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-44.8203,  -4.9061, -10.8336,  -8.0901,  52.8977, -36.2438, -29.9064,
         -6.0056,  -5.2245,  -3.4367,  43.2400,  -3.0386,  10.2216,  25.9328,
          4.1640,   0.5104,  -2.5313, -28.7937,  -6.9382,  -3.5399,  11.5281,
        -16.0345,  -1.4390,  -3.7330, -16.2773,   3.2523, -11.3138, -24.1573

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-47.3610,  -5.1745, -11.4122,  -8.5843,  54.7848, -38.8728, -32.6999,
         -6.2463,  -5.3776,  -3.5818,  45.7020,  -3.1286,  10.5658,  26.8263,
          5.0926,   0.4176,  -2.6452, -30.3760,  -8.1032,  -3.5985,  11.7105,
        -16.2087,  -2.2757,  -3.8708, -17.4454,  15.7701, -11.9751, -25.3664

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-49.7087,  -6.1408, -11.8555,  -9.5835,  57.1650, -40.8097, -33.1438,
         -6.4088,  -5.5432,  -3.7122,  47.2627,  -3.1290,  10.6647,  27.5034,
          0.1341,   0.6383,  -2.5303, -31.5979,  -8.2535,  -3.6198,  12.2795,
        -15.1030,  -2.9871,  -3.9513, -18.1708,  15.5055, -12.4532, -27.0229

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-51.1292,  -6.1906, -12.3036,  -9.7666,  58.5916, -41.5780, -33.9989,
         -6.6885,  -5.8059,  -3.8837,  48.8516,  -3.3411,  11.3005,  28.3814,
          1.8022,   0.6495,  -2.6221, -33.0771,  -8.3275,  -3.8328,  12.8573,
        -15.1221,  -4.3378,  -4.1604, -18.8486,  15.1693, -12.8590, -27.6976

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-53.2494,  -6.2362, -12.8368, -10.4568,  60.5453, -44.1207, -36.5342,
         -6.8601,  -5.8790,  -4.0868,  51.1910,  -3.3503,  12.1217,  29.1781,
          2.8421,   0.5545,  -1.8147, -33.4579,  -8.8331,  -3.8011,  12.9155,
        -12.4510,  -6.1876,  -4.1911, -19.6713,  14.8633, -13.4178, -28.3967

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-55.3659,  -6.3286, -13.3323, -11.1495,  62.6274, -46.0167, -38.3970,
         -6.9515,  -6.0333,  -4.2740,  53.0164,  -3.3803,  12.7183,  29.9011,
          1.8742,   0.5354,  -1.1378, -34.0917,  -8.8697,  -3.8176,  12.9963,
        -13.1260,  -7.7349,  -4.2422, -20.3930,  14.5544, -13.9776, -29.0010

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-57.3595,  -6.0828, -13.7662, -11.2512,  61.6268, -46.8969, -40.0326,
         -7.3344,  -6.3393,  -4.4449,  54.9950,  -3.6403,  13.4416,  30.0931,
          1.4407,   0.4747,  -1.4904, -35.9528,  -9.5287,  -4.1134,  13.7449,
        -13.1212,  -6.2292,  -4.4997, -21.0293,  14.0006, -14.4870, -29.8852

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-58.5293,  -6.1886, -14.2044, -11.5225,  63.0247, -48.1950, -40.8463,
         -7.5633,  -6.6079,  -4.5662,  56.5701,  -3.7892,  13.4400,  30.9936,
         -6.4350,   0.4731,  -1.6762, -37.0522,  -9.7451,  -4.3242,  14.6186,
        -16.8783,  -7.6827,  -4.6247, -21.3185,  13.6127, -14.8580, -30.5763

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-60.0012,  -6.0675, -14.5189, -11.5300,  63.9390, -48.6430, -41.1209,
         -7.8632,  -6.9001,  -4.6906,  58.1569,  -4.0322,  13.9634,  32.1118,
         -5.8023,   0.4833,  -2.2343, -38.5939, -10.0741,  -4.6375,  15.4769,
        -18.8517,  -8.0652,  -4.8332, -21.6886,  13.1708, -15.2178, -31.4578

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-61.8544,  -5.5116, -15.0235, -11.9566,  65.6724, -50.4281, -43.4940,
         -8.0463,  -7.0799,  -4.8779,  60.0625,  -4.1569,  14.8422,  32.7687,
         -4.3267,   0.3248,  -1.8865, -39.2766, -10.2595,  -4.7394,  16.0483,
        -20.9161, -10.4877,  -4.9109, -21.9476,  12.7558, -15.7599, -31.9087

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-64.5432,  -6.6005, -15.6860, -12.7036,  77.5499, -53.6747, -46.1217,
         -8.1960,  -7.1710,  -5.0207,  62.1892,  -4.0859,  14.5262,  32.7530,
         -4.3238,   0.3577,  -1.5178, -40.2746, -11.1059,  -4.6843,  16.1007,
        -21.2118, -17.6317,  -4.9250, -23.2883,  12.3418, -16.4631, -33.3108

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-6.6715e+01, -6.8314e+00, -1.6179e+01, -1.3594e+01,  7.9266e+01,
        -5.5791e+01, -4.8017e+01, -8.3622e+00, -7.3044e+00, -5.1738e+00,
         6.4241e+01, -4.1465e+00,  1.5021e+01,  3.3562e+01, -2.9872e+00,
         3.3685e-01, -1.3453e+00, -4.1626e+01, -1.1888e+01, -4.7635e+00,
         1.6324e+0

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-68.8810,  -8.3365, -16.4906, -13.6643,  78.6081, -56.9522, -47.9758,
         -8.6178,  -7.4494,  -5.2181,  65.2576,  -4.2116,  14.7511,  33.3785,
         -2.7424,   0.5082,  -1.9854, -43.6420, -12.9257,  -4.9469,  16.2890,
        -24.9544, -16.5660,  -5.1683, -25.9559,  11.5994, -17.3753, -35.9613

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-7.0926e+01, -8.3160e+00, -1.6915e+01, -1.4425e+01,  8.0493e+01,
        -5.8707e+01, -4.9653e+01, -8.7438e+00, -7.5770e+00, -5.3677e+00,
         6.6816e+01, -4.2661e+00,  1.5468e+01,  3.3906e+01,  1.0300e+00,
         4.9783e-01, -1.7344e+00, -4.4562e+01, -1.2891e+01, -5.0231e+00,
         1.6564e+0

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-7.3476e+01, -7.5390e+00, -1.7517e+01, -1.4620e+01,  8.1923e+01,
        -6.0271e+01, -5.2130e+01, -9.0771e+00, -7.9047e+00, -5.6523e+00,
         6.9391e+01, -4.4740e+00,  1.6753e+01,  4.0677e+01, -1.2401e-01,
         3.3318e-01, -1.4330e+00, -4.5095e+01, -1.2814e+01, -5.1951e+00,
         1.7172e+0

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-77.0912,  -9.2729, -17.9112, -15.4401,  85.5702, -62.9566, -53.7813,
         -9.4304,  -8.0862,  -5.6092,  70.9650,  -4.5368,  15.3558,  41.6819,
          5.4157,   0.4876,  -3.8019, -49.1503, -15.5852,  -5.5283,  17.3822,
        -24.8935, -18.7910,  -5.6211, -29.1512,  10.2074, -19.1903, -39.6408

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-79.9887, -10.0957, -18.6376, -16.0277,  87.6102, -65.8721, -56.6724,
         -9.5580,  -8.1765,  -5.8783,  73.2176,  -4.4242,  15.6602,  44.0105,
         13.2082,   0.4331,  -2.9224, -49.6712, -15.8836,  -5.4288,  17.3446,
        -25.2165, -21.7167,  -5.6797, -30.3537,   9.8636, -19.9036, -40.0709

Parameter containing:
tensor([ 0.0684,  0.1323, -0.1642,  0.0823,  0.0509,  0.1732,  0.0478, -0.0551,
        -0.2007, -0.2059,  0.1489, -0.1646,  0.2332,  0.1242, -0.0247, -0.1376,
         0.0991, -0.0083,  0.0399, -0.1348,  0.0592,  0.0450, -0.0045, -0.1987,
         0.1918, -0.0792, -0.0139,  0.1729, -0.0468,  0.0981,  0.1689,  0.1505,
        -0.0457,  0.1631,  0.0436,  0.0410, -0.2197,  0.1022,  0.1263,  0.1583,
         0.0615,  0.0542,  0.0362,  0.0455,  0.0895, -0.1509, -0.2119,  0.1106,
        -0.0671,  0.1453, -0.1137, -0.0928,  0.2271,  0.2809,  0.1725, -0.0909,
         0.1248,  0.2105,  0.1663, -0.0871,  0.1872, -0.2238,  0.1055,  0.2279],
       requires_grad=True)
tensor([-82.3007,  -9.9444, -19.1543, -16.3029,  92.3792, -67.4176, -58.4210,
         -9.9104,  -8.5220,  -6.0517,  75.3847,  -4.6340,  16.0610,  44.0425,
         12.9506,   0.3759,  -3.2036, -51.1966, -16.5422,  -5.6757,  18.3025,
        -25.2601, -24.3248,  -5.9271, -30.8628,   9.3133, -20.4846, -41.0196

KeyboardInterrupt: 