In [1]:
# VAE architecture


import torch
from models.base import BaseVAE
from torch import nn
from torch.nn import functional as F
from models.types_ import *


class ResidualBlockEncoder(nn.Module):
    """
    Pre-normalised residual block made of 3x3 convolutions.

    The input is batch-normalised once; the normalised tensor feeds both the
    main path (convolution followed by LeakyReLU) and the shortcut. The
    shortcut is the normalised tensor itself when the block keeps the shape
    (stride 1 and equal channel counts) and a second 3x3 convolution of it
    otherwise.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 stride: int,
                 **kwargs) -> None:
        """
        :param in_channels: (int) Number of channels of the incoming tensor
        :param out_channels: (int) Number of channels produced by the block
        :param stride: (int) Stride shared by both convolutions
        :param kwargs: Accepted but not used
        """
        super().__init__()

        self.in_channels, self.out_channels, self.stride = (
            in_channels, out_channels, stride)

        # Both convolutions share one configuration. conv2 is always created,
        # even when forward() ends up using the identity shortcut.
        conv_cfg = dict(kernel_size=3, stride=stride, padding=1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_cfg)
        self.conv2 = nn.Conv2d(in_channels, out_channels, **conv_cfg)
        self.norm = nn.BatchNorm2d(in_channels)
        self.relu = nn.LeakyReLU()

    def forward(self, x):
        """
        :param x: (Tensor) Input with ``in_channels`` channels
        :return: (Tensor) Sum of the activated main path and the shortcut
        """
        normed = self.norm(x)
        main_path = self.relu(self.conv1(normed))

        changes_shape = self.stride != 1 or self.in_channels != self.out_channels
        shortcut = self.conv2(normed) if changes_shape else normed
        return main_path + shortcut
    
    
class ResidualBlockDecoder(nn.Module):
    """
    Pre-normalised residual block made of 3x3 transposed convolutions.

    The input is batch-normalised once; the normalised tensor feeds both the
    main path (transposed convolution followed by LeakyReLU) and the
    shortcut. The shortcut is the normalised tensor itself when the block
    keeps the shape (stride 1 and equal channel counts) and a second
    transposed convolution of it otherwise.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 stride: int,
                 **kwargs) -> None:
        """
        :param in_channels: (int) Number of channels of the incoming tensor
        :param out_channels: (int) Number of channels produced by the block
        :param stride: (int) Stride shared by both transposed convolutions
        :param kwargs: Accepted but not used
        """
        super().__init__()

        self.in_channels, self.out_channels, self.stride = (
            in_channels, out_channels, stride)

        # Shared configuration; output_padding is tied to the stride
        # (stride - 1). conv2 is always created, even when forward() ends up
        # using the identity shortcut.
        deconv_cfg = dict(kernel_size=3,
                          stride=stride,
                          padding=1,
                          output_padding=stride - 1)
        self.conv1 = nn.ConvTranspose2d(in_channels, out_channels, **deconv_cfg)
        self.conv2 = nn.ConvTranspose2d(in_channels, out_channels, **deconv_cfg)
        self.norm = nn.BatchNorm2d(in_channels)
        self.relu = nn.LeakyReLU()

    def forward(self, x):
        """
        :param x: (Tensor) Input with ``in_channels`` channels
        :return: (Tensor) Sum of the activated main path and the shortcut
        """
        normed = self.norm(x)
        main_path = self.relu(self.conv1(normed))

        changes_shape = self.stride != 1 or self.in_channels != self.out_channels
        shortcut = self.conv2(normed) if changes_shape else normed
        return main_path + shortcut

        
class VAE(BaseVAE):
    """
    Convolutional variational autoencoder assembled from residual blocks.

    The encoder applies one stage per entry of ``hidden_dims``; every stage
    is two shape-preserving residual blocks followed by a stride-2 block, so
    the spatial size is halved ``len(hidden_dims)`` times. Two linear heads
    map the flattened encoder output to the mean and log-variance of the
    latent Gaussian. The decoder mirrors the encoder with transposed
    convolutions and ends in a ``Tanh`` layer.
    """

    def __init__(self,
                 in_channels: int,
                 in_size: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 **kwargs) -> None:
        """
        :param in_channels: (int) Number of channels of the input images
        :param in_size: (int) Side length of the (square) input images; must
            be divisible by ``2 ** len(hidden_dims)``
        :param latent_dim: (int) Dimensionality of the latent space
        :param hidden_dims: (List) Output channels of each encoder stage;
            defaults to ``[32, 64, 128, 256, 512]``
        :param kwargs: Accepted but not used
        :raises ValueError: If ``in_size`` is not divisible by
            ``2 ** len(hidden_dims)``
        """
        super(VAE, self).__init__()

        # Resolve the default BEFORE anything indexes into hidden_dims, and
        # work on a private copy so the caller's list is never reordered.
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            hidden_dims = list(hidden_dims)

        self.latent_dim = latent_dim
        self.nb_last_channels = hidden_dims[-1]
        self.out_channels = in_channels

        # Every encoder stage halves the spatial size once.
        div = 2 ** len(hidden_dims)

        # Make sure input dimension is usable
        if in_size % div != 0:
            raise ValueError('Input size not compatible with number of hidden layers.')
        self.smallest_size = int(in_size // div)

        # Build Encoder
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    ResidualBlockEncoder(in_channels=in_channels, out_channels=in_channels, stride=1),
                    ResidualBlockEncoder(in_channels=in_channels, out_channels=in_channels, stride=1),
                    ResidualBlockEncoder(in_channels=in_channels, out_channels=h_dim, stride=2))
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)

        # Size of the flattened encoder output (channels * height * width).
        flat_features = hidden_dims[-1] * self.smallest_size ** 2
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_var = nn.Linear(flat_features, latent_dim)

        # Build Decoder
        self.decoder_input = nn.Linear(latent_dim, flat_features)

        # The decoder walks the channel sizes in the opposite direction.
        decoder_dims = hidden_dims[::-1]

        modules = []
        for current_dim, next_dim in zip(decoder_dims, decoder_dims[1:]):
            modules.append(
                nn.Sequential(
                    ResidualBlockDecoder(current_dim, current_dim, stride=1),
                    ResidualBlockDecoder(current_dim, current_dim, stride=1),
                    ResidualBlockDecoder(current_dim, next_dim, stride=2))
            )

        self.decoder = nn.Sequential(*modules)

        # The loop above performs len(hidden_dims) - 1 upsampling steps; the
        # final layer performs the last one and maps back to image channels.
        self.final_layer = nn.Sequential(
                            nn.BatchNorm2d(decoder_dims[-1]),
                            nn.ConvTranspose2d(decoder_dims[-1],
                                               decoder_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.LeakyReLU(),
                            nn.Conv2d(decoder_dims[-1], out_channels=self.out_channels,
                                      kernel_size=3, padding=1),
                            nn.Tanh())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (List[Tensor]) ``[mu, log_var]`` of the latent Gaussian
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        # Undo the flattening performed in encode().
        result = result.view(-1, self.nb_last_channels, self.smallest_size, self.smallest_size)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from
        N(0,1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        """
        Encodes the input, samples a latent code and decodes it.
        :param input: (Tensor) [B x C x H x W]
        :return: (List[Tensor]) ``[reconstruction, input, mu, log_var]``
        """
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), input, mu, log_var]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        r"""
        Computes the VAE loss function.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args: ``(reconstruction, input, mu, log_var)`` as returned by
            ``forward``
        :param kwargs: Must contain ``'M_N'``, the weight of the KL term
        :return: (dict) Keys ``'loss'``, ``'Reconstruction_Loss'`` and
            ``'KLD'``
        """
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]

        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
        recons_loss = F.mse_loss(recons, input)

        # Sum over latent dimensions, mean over the batch.
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        loss = recons_loss + kld_weight * kld_loss
        # 'KLD' is reported with the sign flipped relative to the term that
        # enters 'loss'; kept as is for existing consumers of this dict.
        return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'KLD': -kld_loss}

    def sample(self,
               num_samples: int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]

In [2]:
# Define Trainer

from torchvision.utils import save_image

# Note: the trainer currently uses only the reconstruction loss
class Trainer(object):
    """
    Runs training, validation and sampling for a VAE-style model.

    The model is expected to return ``[reconstruction, input, mu, log_var]``
    from its forward pass and the loss function to return a dict with a
    ``'Reconstruction_Loss'`` entry; only that entry is optimised and
    reported here.
    """

    def __init__(self, model,
                 optimizer, loss_function,
                 loader_train, loader_val,
                 dtype, device, **in_params):
        """
        :param model: PyTorch model of the neural network

        :param optimizer: PyTorch optimizer

        :param loss_function: Callable invoked as
            ``loss_function(*model_output, M_N=weight)``

        :param loader_train: Data loader yielding ``(input, target)`` pairs
            for training; the target is ignored

        :param loader_val: Data loader yielding ``(input, target)`` pairs
            for validation; the target is ignored

        :param dtype: dtype the inputs are cast to

        :param device: Device the model and the inputs are moved to

        :param in_params: Must contain ``print_every`` (how often the loss is
            printed during training, in batches), ``batch_size``,
            ``input_size`` (side length of the images) and ``path`` (output
            directory for the saved images)
        """
        # Create attributes:
        self.device = device
        self.model = model.to(device=self.device)  # move the model parameters to CPU/GPU
        self.optimizer = optimizer
        self.loss_function = loss_function
        self.loader_train = loader_train
        self.loader_val = loader_val
        self.print_every = in_params["print_every"]
        self.dtype = dtype
        self.batch_size = in_params["batch_size"]
        self.input_size = in_params["input_size"]
        self.path = in_params["path"]

    def train_model(self, epoch):
        """
        Trains the model for one pass over ``loader_train``.

        - epoch: An integer giving the epoch
        """
        train_loss = 0
        self.model.train()  # put model to training mode

        # Weight of the KL term: one minibatch relative to the whole dataset.
        kld_weight = self.batch_size / len(self.loader_train.dataset)

        for t, (batch, _) in enumerate(self.loader_train):

            batch = batch.to(device=self.device, dtype=self.dtype)  # move to device, e.g. GPU

            # do a step in training
            args = self.model(batch)
            loss = self.loss_function(*args, M_N=kld_weight)['Reconstruction_Loss']
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
            train_loss += loss.item()  # accumulate for average loss
            self.optimizer.step()

            # print loss
            # NOTE(review): the divisions by the batch / dataset size below
            # assume a sum-reduced loss; with a mean-reduced loss the printed
            # values are scaled down by roughly the batch size.
            if t % self.print_every == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, t * len(args[1]), len(self.loader_train.dataset),
                    100. * t / len(self.loader_train),
                    loss.item() / len(args[1])))
        # print average loss
        print('====> Epoch: {} Average loss: {:.6f}'.format(
              epoch, train_loss / len(self.loader_train.dataset)))

    def test_model(self, epoch):
        """
        Evaluates the model on ``loader_val`` and saves a comparison image.

        The first (up to) 8 inputs of the first batch are saved next to
        their reconstructions as ``<path>/reconstruction_<epoch>.png``.

        - epoch: An integer giving the epoch
        """
        self.model.eval()  # Put model to evaluation mode
        test_loss = 0.

        # Weight of the KL term: one minibatch relative to the whole dataset.
        kld_weight = self.batch_size / len(self.loader_val.dataset)

        with torch.no_grad():
            # During validation, we accumulate these values across the whole dataset and then average at the end:
            for i, (batch, _) in enumerate(self.loader_val):
                batch = batch.to(device=self.device, dtype=self.dtype)  # move to device, e.g. GPU

                # compute loss and accumulate
                args = self.model(batch)
                test_loss += self.loss_function(*args, M_N=kld_weight)['Reconstruction_Loss'].item()
                if i == 0:
                    n = min(args[1].size(0), 8)
                    # -1 lets the view follow the actual batch size, so a
                    # first batch shorter than batch_size does not fail.
                    recons = args[0].view(-1, self.model.out_channels,
                                          self.input_size, self.input_size)
                    comparison = torch.cat([args[1][:n], recons[:n]])
                    save_image(comparison.cpu(),
                               self.path + '/reconstruction_' + str(epoch) + '.png', nrow=n)

        # print average loss
        test_loss /= len(self.loader_val.dataset)
        print('====> Test set loss: {:.6f}'.format(test_loss))

    def train_and_test(self, epochs):
        """
        Alternates training and validation for ``epochs`` epochs.

        After every epoch 64 samples are drawn from the model and saved as
        ``<path>/sample_<epoch>.png``.

        - epochs: An integer giving the number of epochs to run
        """
        for e in range(1, epochs + 1):
            self.train_model(e)
            self.test_model(e)
            with torch.no_grad():
                # Use the trainer's own device rather than a module-level
                # global that may not exist where the class is reused.
                sample = self.model.sample(64, self.device)
                save_image(sample.view(64, self.model.out_channels, self.input_size, self.input_size),
                           self.path + '/sample_' + str(e) + '.png')

In [None]:
# Train VAE

from torchvision import datasets, transforms

import os

# Run configuration shared by the data pipeline, the model and the Trainer.
in_params = {"batch_size": 32,
        "epochs": 2000,
        "no_cuda": False,
        "seed": 1,
        "print_every": 10,
        "input_size": 64,
        "path": 'results_trial'
        }
in_params["cuda"] = not in_params["no_cuda"] and torch.cuda.is_available()
torch.manual_seed(in_params["seed"])

device = torch.device("cuda" if in_params["cuda"] else "cpu")

# Extra DataLoader options that are only used when running on the GPU.
kwargs = {'num_workers': 1, 'pin_memory': True} if in_params["cuda"] else {}

# The Trainer writes reconstruction and sample images into this directory;
# create it up front so the first save does not fail on a missing path.
os.makedirs(in_params["path"], exist_ok=True)


# 64 / 2**4 = 4, so four hidden stages are compatible with the input size.
model = VAE(in_channels=3,
           in_size=in_params["input_size"],
           latent_dim=250,
           hidden_dims=[32, 64, 128, 160])

# transformations of input images before feeding into nn
# NOTE(review): the decoder ends in Tanh (outputs in [-1, 1]); confirm that
# the value range produced by this pipeline matches.
transformations = transforms.Compose([
    transforms.Resize(128),
    transforms.RandomCrop(in_params["input_size"]),
    transforms.ColorJitter(0.4,0.4,0.4,0.0),
    transforms.ToTensor()])

#load data into DataLoader
# NOTE(review): the test set shares the random crop / colour jitter of the
# training set and is shuffled below, so test losses and the saved
# reconstructions vary between epochs for reasons unrelated to the model.
train_dataset = datasets.ImageFolder('data_to_try/full_data/training',
                                 transform=transformations
                                   )
test_dataset = datasets.ImageFolder('data_to_try/full_data/test',
                                 transform=transformations
                                   )



train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=in_params["batch_size"],
                                           shuffle=True,
                                           drop_last=True,
                                           **kwargs
                                           )
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=in_params["batch_size"],
                                          shuffle=True,
                                          drop_last=True,
                                          **kwargs
                                          )

# Build the optimizer:
params = model.parameters()
learning_rate = 1e-4
optimizer = torch.optim.AdamW(params, lr=learning_rate)


# Build the trainer with the model's own VAE loss function:
trainer = Trainer(model, optimizer, model.loss_function,
        train_loader, test_loader, torch.float32, device,**in_params )

# Start training:
trainer.train_and_test(in_params["epochs"])

====> Epoch: 1 Average loss: 0.002859
====> Test set loss: 0.001370
====> Epoch: 2 Average loss: 0.000854
====> Test set loss: 0.000883
====> Epoch: 3 Average loss: 0.000601
====> Test set loss: 0.000691
====> Epoch: 4 Average loss: 0.000534
====> Test set loss: 0.000644
====> Epoch: 5 Average loss: 0.000500
====> Test set loss: 0.000607
====> Epoch: 6 Average loss: 0.000474
====> Test set loss: 0.000609
====> Epoch: 7 Average loss: 0.000453
====> Test set loss: 0.000606
====> Epoch: 8 Average loss: 0.000417
====> Test set loss: 0.000585
====> Epoch: 9 Average loss: 0.000425
====> Test set loss: 0.000566
====> Epoch: 10 Average loss: 0.000397
====> Test set loss: 0.000490
====> Epoch: 11 Average loss: 0.000400
====> Test set loss: 0.000498
====> Epoch: 12 Average loss: 0.000386
====> Test set loss: 0.000507
====> Epoch: 13 Average loss: 0.000383
====> Test set loss: 0.000503
====> Epoch: 14 Average loss: 0.000369
====> Test set loss: 0.000509
====> Epoch: 15 Average loss: 0.000370
====

====> Test set loss: 0.000378
====> Epoch: 28 Average loss: 0.000256
====> Test set loss: 0.000394
====> Epoch: 29 Average loss: 0.000247
====> Test set loss: 0.000394
====> Epoch: 30 Average loss: 0.000244
====> Test set loss: 0.000340
====> Epoch: 31 Average loss: 0.000248
====> Test set loss: 0.000351
====> Epoch: 32 Average loss: 0.000247
====> Test set loss: 0.000361
====> Epoch: 33 Average loss: 0.000241
====> Test set loss: 0.000348
====> Epoch: 34 Average loss: 0.000225
====> Test set loss: 0.000352
====> Epoch: 35 Average loss: 0.000236
====> Test set loss: 0.000350
====> Epoch: 36 Average loss: 0.000216
====> Test set loss: 0.000325
====> Epoch: 37 Average loss: 0.000228
====> Test set loss: 0.000332
====> Epoch: 38 Average loss: 0.000218
====> Test set loss: 0.000357
====> Epoch: 39 Average loss: 0.000220
====> Test set loss: 0.000366
====> Epoch: 40 Average loss: 0.000228
====> Test set loss: 0.000343
====> Epoch: 41 Average loss: 0.000223
====> Test set loss: 0.000362
====

====> Epoch: 54 Average loss: 0.000214
====> Test set loss: 0.000311
====> Epoch: 55 Average loss: 0.000201
====> Test set loss: 0.000296
====> Epoch: 56 Average loss: 0.000204
====> Test set loss: 0.000306
====> Epoch: 57 Average loss: 0.000202
====> Test set loss: 0.000315
====> Epoch: 58 Average loss: 0.000201
====> Test set loss: 0.000303
====> Epoch: 59 Average loss: 0.000188
====> Test set loss: 0.000303
====> Epoch: 60 Average loss: 0.000180
====> Test set loss: 0.000306
====> Epoch: 61 Average loss: 0.000185
====> Test set loss: 0.000289
====> Epoch: 62 Average loss: 0.000185
====> Test set loss: 0.000306
====> Epoch: 63 Average loss: 0.000196
====> Test set loss: 0.000283
====> Epoch: 64 Average loss: 0.000194
====> Test set loss: 0.000297
====> Epoch: 65 Average loss: 0.000184
====> Test set loss: 0.000304
====> Epoch: 66 Average loss: 0.000183
====> Test set loss: 0.000289
====> Epoch: 67 Average loss: 0.000191
====> Test set loss: 0.000305
====> Epoch: 68 Average loss: 0.00

====> Epoch: 81 Average loss: 0.000174
====> Test set loss: 0.000259
====> Epoch: 82 Average loss: 0.000175
====> Test set loss: 0.000282
====> Epoch: 83 Average loss: 0.000163
====> Test set loss: 0.000287
====> Epoch: 84 Average loss: 0.000167
====> Test set loss: 0.000286
====> Epoch: 85 Average loss: 0.000170
====> Test set loss: 0.000304
====> Epoch: 86 Average loss: 0.000173
====> Test set loss: 0.000292
====> Epoch: 87 Average loss: 0.000168
====> Test set loss: 0.000280
====> Epoch: 88 Average loss: 0.000176
====> Test set loss: 0.000265
====> Epoch: 89 Average loss: 0.000161
====> Test set loss: 0.000284
====> Epoch: 90 Average loss: 0.000163
====> Test set loss: 0.000260
====> Epoch: 91 Average loss: 0.000157
====> Test set loss: 0.000282
====> Epoch: 92 Average loss: 0.000157
====> Test set loss: 0.000259
====> Epoch: 93 Average loss: 0.000162
====> Test set loss: 0.000302
====> Epoch: 94 Average loss: 0.000162
====> Test set loss: 0.000266
====> Epoch: 95 Average loss: 0.00

====> Test set loss: 0.000255
====> Epoch: 108 Average loss: 0.000149
====> Test set loss: 0.000258
====> Epoch: 109 Average loss: 0.000147
====> Test set loss: 0.000263
====> Epoch: 110 Average loss: 0.000149
====> Test set loss: 0.000269
====> Epoch: 111 Average loss: 0.000148
====> Test set loss: 0.000242
====> Epoch: 112 Average loss: 0.000157
====> Test set loss: 0.000261
====> Epoch: 113 Average loss: 0.000151
====> Test set loss: 0.000259
====> Epoch: 114 Average loss: 0.000143
====> Test set loss: 0.000267
====> Epoch: 115 Average loss: 0.000154
====> Test set loss: 0.000240
====> Epoch: 116 Average loss: 0.000153
====> Test set loss: 0.000256
====> Epoch: 117 Average loss: 0.000150
====> Test set loss: 0.000262
====> Epoch: 118 Average loss: 0.000149
====> Test set loss: 0.000267
====> Epoch: 119 Average loss: 0.000141
====> Test set loss: 0.000253
====> Epoch: 120 Average loss: 0.000147
====> Test set loss: 0.000260
====> Epoch: 121 Average loss: 0.000146
====> Test set loss:

====> Epoch: 134 Average loss: 0.000140
====> Test set loss: 0.000265
====> Epoch: 135 Average loss: 0.000133
====> Test set loss: 0.000245
====> Epoch: 136 Average loss: 0.000140
====> Test set loss: 0.000241
====> Epoch: 137 Average loss: 0.000140
====> Test set loss: 0.000267
====> Epoch: 138 Average loss: 0.000146
====> Test set loss: 0.000259
====> Epoch: 139 Average loss: 0.000138
====> Test set loss: 0.000269
====> Epoch: 140 Average loss: 0.000146
====> Test set loss: 0.000255
====> Epoch: 141 Average loss: 0.000147
====> Test set loss: 0.000256
====> Epoch: 142 Average loss: 0.000140
====> Test set loss: 0.000273
====> Epoch: 143 Average loss: 0.000136
====> Test set loss: 0.000240
====> Epoch: 144 Average loss: 0.000133
====> Test set loss: 0.000259
====> Epoch: 145 Average loss: 0.000134
====> Test set loss: 0.000244
====> Epoch: 146 Average loss: 0.000130
====> Test set loss: 0.000257
====> Epoch: 147 Average loss: 0.000139
====> Test set loss: 0.000250
====> Epoch: 148 Ave

====> Epoch: 160 Average loss: 0.000126
====> Test set loss: 0.000252
====> Epoch: 161 Average loss: 0.000130
====> Test set loss: 0.000251
====> Epoch: 162 Average loss: 0.000129
====> Test set loss: 0.000258
====> Epoch: 163 Average loss: 0.000133
====> Test set loss: 0.000241
====> Epoch: 164 Average loss: 0.000133
====> Test set loss: 0.000236
====> Epoch: 165 Average loss: 0.000131
====> Test set loss: 0.000267
====> Epoch: 166 Average loss: 0.000128
====> Test set loss: 0.000250
====> Epoch: 167 Average loss: 0.000130
====> Test set loss: 0.000223
====> Epoch: 168 Average loss: 0.000131
====> Test set loss: 0.000260
====> Epoch: 169 Average loss: 0.000131
====> Test set loss: 0.000230
====> Epoch: 170 Average loss: 0.000125
====> Test set loss: 0.000235
====> Epoch: 171 Average loss: 0.000129
====> Test set loss: 0.000242
====> Epoch: 172 Average loss: 0.000128
====> Test set loss: 0.000251
====> Epoch: 173 Average loss: 0.000129
====> Test set loss: 0.000256
====> Epoch: 174 Ave

====> Epoch: 186 Average loss: 0.000126
====> Test set loss: 0.000226
====> Epoch: 187 Average loss: 0.000133
====> Test set loss: 0.000272
====> Epoch: 188 Average loss: 0.000132
====> Test set loss: 0.000250
====> Epoch: 189 Average loss: 0.000120
====> Test set loss: 0.000231
====> Epoch: 190 Average loss: 0.000129
====> Test set loss: 0.000235
====> Epoch: 191 Average loss: 0.000126
====> Test set loss: 0.000233
====> Epoch: 192 Average loss: 0.000136
====> Test set loss: 0.000287
====> Epoch: 193 Average loss: 0.000141
====> Test set loss: 0.000271
====> Epoch: 194 Average loss: 0.000129
====> Test set loss: 0.000237
====> Epoch: 195 Average loss: 0.000123
====> Test set loss: 0.000233
====> Epoch: 196 Average loss: 0.000120
====> Test set loss: 0.000219
====> Epoch: 197 Average loss: 0.000123
====> Test set loss: 0.000249
====> Epoch: 198 Average loss: 0.000126
====> Test set loss: 0.000255
====> Epoch: 199 Average loss: 0.000121
====> Test set loss: 0.000236
====> Epoch: 200 Ave

====> Epoch: 213 Average loss: 0.000119
====> Test set loss: 0.000236
====> Epoch: 214 Average loss: 0.000113
====> Test set loss: 0.000249
====> Epoch: 215 Average loss: 0.000118
====> Test set loss: 0.000246
====> Epoch: 216 Average loss: 0.000116
====> Test set loss: 0.000229
====> Epoch: 217 Average loss: 0.000116
====> Test set loss: 0.000260
====> Epoch: 218 Average loss: 0.000122
====> Test set loss: 0.000248
====> Epoch: 219 Average loss: 0.000123
====> Test set loss: 0.000231
====> Epoch: 220 Average loss: 0.000114
====> Test set loss: 0.000263
====> Epoch: 221 Average loss: 0.000117
====> Test set loss: 0.000249
====> Epoch: 222 Average loss: 0.000112
====> Test set loss: 0.000233
====> Epoch: 223 Average loss: 0.000119
====> Test set loss: 0.000243
====> Epoch: 224 Average loss: 0.000116
====> Test set loss: 0.000224
====> Epoch: 225 Average loss: 0.000113
====> Test set loss: 0.000243
====> Epoch: 226 Average loss: 0.000115
====> Test set loss: 0.000219
====> Epoch: 227 Ave

====> Epoch: 239 Average loss: 0.000107
====> Test set loss: 0.000236
====> Epoch: 240 Average loss: 0.000115
====> Test set loss: 0.000234
====> Epoch: 241 Average loss: 0.000110
====> Test set loss: 0.000249
====> Epoch: 242 Average loss: 0.000112
====> Test set loss: 0.000233
====> Epoch: 243 Average loss: 0.000111
====> Test set loss: 0.000240
====> Epoch: 244 Average loss: 0.000112
====> Test set loss: 0.000241
====> Epoch: 245 Average loss: 0.000111
====> Test set loss: 0.000225
====> Epoch: 246 Average loss: 0.000116
====> Test set loss: 0.000227
====> Epoch: 247 Average loss: 0.000110
====> Test set loss: 0.000230
====> Epoch: 248 Average loss: 0.000114
====> Test set loss: 0.000235
====> Epoch: 249 Average loss: 0.000118
====> Test set loss: 0.000224
====> Epoch: 250 Average loss: 0.000111
====> Test set loss: 0.000235
====> Epoch: 251 Average loss: 0.000155
====> Test set loss: 0.000269
====> Epoch: 252 Average loss: 0.000118
====> Test set loss: 0.000242
====> Epoch: 253 Ave

====> Epoch: 265 Average loss: 0.000106
====> Test set loss: 0.000222
====> Epoch: 266 Average loss: 0.000110
====> Test set loss: 0.000226
====> Epoch: 267 Average loss: 0.000108
====> Test set loss: 0.000231
====> Epoch: 268 Average loss: 0.000120
====> Test set loss: 0.000235
====> Epoch: 269 Average loss: 0.000110
====> Test set loss: 0.000250
====> Epoch: 270 Average loss: 0.000109
====> Test set loss: 0.000219
====> Epoch: 271 Average loss: 0.000109
====> Test set loss: 0.000255
====> Epoch: 272 Average loss: 0.000112
====> Test set loss: 0.000244
====> Epoch: 273 Average loss: 0.000110
====> Test set loss: 0.000233
====> Epoch: 274 Average loss: 0.000103
====> Test set loss: 0.000229
====> Epoch: 275 Average loss: 0.000103
====> Test set loss: 0.000236
====> Epoch: 276 Average loss: 0.000107
====> Test set loss: 0.000233
====> Epoch: 277 Average loss: 0.000111
====> Test set loss: 0.000253
====> Epoch: 278 Average loss: 0.000113
====> Test set loss: 0.000228
====> Epoch: 279 Ave

====> Test set loss: 0.000243
====> Epoch: 292 Average loss: 0.000105
====> Test set loss: 0.000217
====> Epoch: 293 Average loss: 0.000100
====> Test set loss: 0.000246
====> Epoch: 294 Average loss: 0.000105
====> Test set loss: 0.000221
====> Epoch: 295 Average loss: 0.000102
====> Test set loss: 0.000224
====> Epoch: 296 Average loss: 0.000103
====> Test set loss: 0.000219
====> Epoch: 297 Average loss: 0.000102
====> Test set loss: 0.000228
====> Epoch: 298 Average loss: 0.000102
====> Test set loss: 0.000234
====> Epoch: 299 Average loss: 0.000102
====> Test set loss: 0.000224
====> Epoch: 300 Average loss: 0.000103
====> Test set loss: 0.000237
====> Epoch: 301 Average loss: 0.000104
====> Test set loss: 0.000243
====> Epoch: 302 Average loss: 0.000107
====> Test set loss: 0.000232
====> Epoch: 303 Average loss: 0.000104
====> Test set loss: 0.000243
====> Epoch: 304 Average loss: 0.000104
====> Test set loss: 0.000225
====> Epoch: 305 Average loss: 0.000107
====> Test set loss:

====> Epoch: 318 Average loss: 0.000100
====> Test set loss: 0.000249
====> Epoch: 319 Average loss: 0.000100
====> Test set loss: 0.000251
====> Epoch: 320 Average loss: 0.000104
====> Test set loss: 0.000214
====> Epoch: 321 Average loss: 0.000098
====> Test set loss: 0.000220
====> Epoch: 322 Average loss: 0.000105
====> Test set loss: 0.000238
====> Epoch: 323 Average loss: 0.000103
====> Test set loss: 0.000237
====> Epoch: 324 Average loss: 0.000100
====> Test set loss: 0.000224
====> Epoch: 325 Average loss: 0.000101
====> Test set loss: 0.000225
====> Epoch: 326 Average loss: 0.000102
====> Test set loss: 0.000248
====> Epoch: 327 Average loss: 0.000106
====> Test set loss: 0.000221
====> Epoch: 328 Average loss: 0.000104
====> Test set loss: 0.000237
====> Epoch: 329 Average loss: 0.000106
====> Test set loss: 0.000222
====> Epoch: 330 Average loss: 0.000102
====> Test set loss: 0.000225
====> Epoch: 331 Average loss: 0.000099
====> Test set loss: 0.000225
====> Epoch: 332 Ave

====> Epoch: 344 Average loss: 0.000104
====> Test set loss: 0.000234
====> Epoch: 345 Average loss: 0.000102
====> Test set loss: 0.000239
====> Epoch: 346 Average loss: 0.000106
====> Test set loss: 0.000229
====> Epoch: 347 Average loss: 0.000100
====> Test set loss: 0.000236
====> Epoch: 348 Average loss: 0.000103
====> Test set loss: 0.000218
====> Epoch: 349 Average loss: 0.000103
====> Test set loss: 0.000234
====> Epoch: 350 Average loss: 0.000100
====> Test set loss: 0.000228
====> Epoch: 351 Average loss: 0.000098
====> Test set loss: 0.000243
====> Epoch: 352 Average loss: 0.000096
====> Test set loss: 0.000226
====> Epoch: 353 Average loss: 0.000095
====> Test set loss: 0.000224
====> Epoch: 354 Average loss: 0.000095
====> Test set loss: 0.000221
====> Epoch: 355 Average loss: 0.000096
====> Test set loss: 0.000244
====> Epoch: 356 Average loss: 0.000098
====> Test set loss: 0.000216
====> Epoch: 357 Average loss: 0.000095
====> Test set loss: 0.000230
====> Epoch: 358 Ave

====> Epoch: 370 Average loss: 0.000094
====> Test set loss: 0.000231
====> Epoch: 371 Average loss: 0.000099
====> Test set loss: 0.000247
====> Epoch: 372 Average loss: 0.000096
====> Test set loss: 0.000238
====> Epoch: 373 Average loss: 0.000096
====> Test set loss: 0.000232
====> Epoch: 374 Average loss: 0.000096
====> Test set loss: 0.000213
====> Epoch: 375 Average loss: 0.000102
====> Test set loss: 0.000265
====> Epoch: 376 Average loss: 0.000100
====> Test set loss: 0.000222
====> Epoch: 377 Average loss: 0.000092
====> Test set loss: 0.000233
====> Epoch: 378 Average loss: 0.000095
====> Test set loss: 0.000212
====> Epoch: 379 Average loss: 0.000101
====> Test set loss: 0.000222
====> Epoch: 380 Average loss: 0.000097
====> Test set loss: 0.000229
====> Epoch: 381 Average loss: 0.000099
====> Test set loss: 0.000224
====> Epoch: 382 Average loss: 0.000097
====> Test set loss: 0.000236
====> Epoch: 383 Average loss: 0.000096
====> Test set loss: 0.000236
====> Epoch: 384 Ave

====> Epoch: 397 Average loss: 0.000102
====> Test set loss: 0.000237
====> Epoch: 398 Average loss: 0.000095
====> Test set loss: 0.000230
====> Epoch: 399 Average loss: 0.000095
====> Test set loss: 0.000211
====> Epoch: 400 Average loss: 0.000092
====> Test set loss: 0.000230
====> Epoch: 401 Average loss: 0.000099
====> Test set loss: 0.000237
====> Epoch: 402 Average loss: 0.000091
====> Test set loss: 0.000225
====> Epoch: 403 Average loss: 0.000097
====> Test set loss: 0.000250
====> Epoch: 404 Average loss: 0.000092
====> Test set loss: 0.000214
====> Epoch: 405 Average loss: 0.000096
====> Test set loss: 0.000245
====> Epoch: 406 Average loss: 0.000092
====> Test set loss: 0.000231
====> Epoch: 407 Average loss: 0.000091
====> Test set loss: 0.000227
====> Epoch: 408 Average loss: 0.000091
====> Test set loss: 0.000231
====> Epoch: 409 Average loss: 0.000097
====> Test set loss: 0.000228
====> Epoch: 410 Average loss: 0.000092
====> Test set loss: 0.000225
====> Epoch: 411 Ave

====> Epoch: 423 Average loss: 0.000092
====> Test set loss: 0.000229
====> Epoch: 424 Average loss: 0.000089
====> Test set loss: 0.000240
====> Epoch: 425 Average loss: 0.000088
====> Test set loss: 0.000237
====> Epoch: 426 Average loss: 0.000094
====> Test set loss: 0.000232
====> Epoch: 427 Average loss: 0.000094
====> Test set loss: 0.000228
====> Epoch: 428 Average loss: 0.000091
====> Test set loss: 0.000201
====> Epoch: 429 Average loss: 0.000090
====> Test set loss: 0.000213
====> Epoch: 430 Average loss: 0.000092
====> Test set loss: 0.000222
====> Epoch: 431 Average loss: 0.000090
====> Test set loss: 0.000234
====> Epoch: 432 Average loss: 0.000092
====> Test set loss: 0.000244
====> Epoch: 433 Average loss: 0.000092
====> Test set loss: 0.000222
====> Epoch: 434 Average loss: 0.000089
====> Test set loss: 0.000232
====> Epoch: 435 Average loss: 0.000093
====> Test set loss: 0.000237
====> Epoch: 436 Average loss: 0.000096
====> Test set loss: 0.000239
====> Epoch: 437 Ave

====> Epoch: 449 Average loss: 0.000086
====> Test set loss: 0.000207
====> Epoch: 450 Average loss: 0.000086
====> Test set loss: 0.000232
====> Epoch: 451 Average loss: 0.000090
====> Test set loss: 0.000205
====> Epoch: 452 Average loss: 0.000091
====> Test set loss: 0.000226
====> Epoch: 453 Average loss: 0.000087
====> Test set loss: 0.000216
====> Epoch: 454 Average loss: 0.000091
====> Test set loss: 0.000237
====> Epoch: 455 Average loss: 0.000088
====> Test set loss: 0.000221
====> Epoch: 456 Average loss: 0.000091
====> Test set loss: 0.000226
====> Epoch: 457 Average loss: 0.000101
====> Test set loss: 0.000208
====> Epoch: 458 Average loss: 0.000090
====> Test set loss: 0.000228
====> Epoch: 459 Average loss: 0.000089
====> Test set loss: 0.000219
====> Epoch: 460 Average loss: 0.000103
====> Test set loss: 0.000244
====> Epoch: 461 Average loss: 0.000098
====> Test set loss: 0.000229
====> Epoch: 462 Average loss: 0.000096
====> Test set loss: 0.000250
====> Epoch: 463 Ave

====> Test set loss: 0.000220
====> Epoch: 476 Average loss: 0.000090
====> Test set loss: 0.000228
====> Epoch: 477 Average loss: 0.000090
====> Test set loss: 0.000226
====> Epoch: 478 Average loss: 0.000090
====> Test set loss: 0.000216
====> Epoch: 479 Average loss: 0.000095
====> Test set loss: 0.000230
====> Epoch: 480 Average loss: 0.000087
====> Test set loss: 0.000236
====> Epoch: 481 Average loss: 0.000092
====> Test set loss: 0.000203
====> Epoch: 482 Average loss: 0.000089
====> Test set loss: 0.000214
====> Epoch: 483 Average loss: 0.000089
====> Test set loss: 0.000227
====> Epoch: 484 Average loss: 0.000086
====> Test set loss: 0.000215
====> Epoch: 485 Average loss: 0.000085
====> Test set loss: 0.000229
====> Epoch: 486 Average loss: 0.000083
====> Test set loss: 0.000246
====> Epoch: 487 Average loss: 0.000087
====> Test set loss: 0.000229
====> Epoch: 488 Average loss: 0.000084
====> Test set loss: 0.000230
====> Epoch: 489 Average loss: 0.000085
====> Test set loss: