In [1]:
import torch
import torchvision
from torch import nn
import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import trange, tqdm
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image

In [2]:
class lfw(Dataset):
    """In-memory LFW dataset yielding (cropped, pixelated, blurred) image triples.

    Image ``i`` is read as ``{i}.jpg`` from ``<base_path>cropped/``, ``<base_path>pixelated/``
    and ``<base_path>blurred/``, for ``i`` in ``range(len(os.listdir(<base_path>cropped/)))``.
    Each set is stored as a float tensor of shape ``(n_files, 3, photo_length, photo_length)``
    (torchvision puts the RGB channels first).

    Args:
        base_path: folder holding the three image sub-folders; must end with a path separator.
        photo_length: side length of the square images.
        subtract_mean: centre each tensor on its own global mean (applied after scaling to [0, 1]).
        divide_std: divide each tensor by its own global std (applied after scaling to [0, 1]).
        device: ``'cuda'`` moves the tensors to the GPU; any other value keeps them on the CPU.
    """

    def __init__(self, base_path: str = './data/', photo_length: int = 112, subtract_mean: bool = False, divide_std: bool = False, device: str = 'cpu'):
        self.cropped_file_path = base_path + 'cropped/'
        self.pixelated_file_path = base_path + 'pixelated/'
        self.blurred_file_path = base_path + 'blurred/'

        self.n_files = len(os.listdir(self.cropped_file_path))
        self.photo_length = photo_length

        shape = (self.n_files, 3, self.photo_length, self.photo_length)  # channels first, as torchvision returns
        self.cropped = torch.empty(shape)
        self.pixelated = torch.empty(shape)
        self.blurred = torch.empty(shape)

        for i in range(self.n_files):
            file_name = f'{i}.jpg'
            self.cropped[i] = torchvision.io.read_image(self.cropped_file_path + file_name)
            self.pixelated[i] = torchvision.io.read_image(self.pixelated_file_path + file_name)
            self.blurred[i] = torchvision.io.read_image(self.blurred_file_path + file_name)

        # Scale raw pixel values to [0, 1] *before* the optional standardisation: dividing by 255
        # afterwards would shrink a unit-std tensor down to std 1/255.
        for images in (self.cropped, self.pixelated, self.blurred):
            images /= 255
            if subtract_mean:
                images -= images.mean()
            if divide_std:
                images /= images.std()

        # Decide from the ``device`` argument: ``self.device`` does not exist until it is set here.
        if device == 'cuda':
            self.cropped, self.pixelated, self.blurred = (
                images.to('cuda') for images in (self.cropped, self.pixelated, self.blurred))
            self.device = 'cuda'
        else:
            self.device = 'cpu'

    def __getitem__(self, index):
        return self.cropped[index], self.pixelated[index], self.blurred[index]

    def __len__(self):
        return self.n_files

In [3]:
# d = lfw() # roughly 2.5-3 minutes, about 5-6 GB of RAM used (without CUDA)

In [4]:
# d[6116][0].shape -> 3*112*112 (channels first)
# to move the channel axis last, as e.g. matplotlib expects, use np.transpose with axes = (1, 2, 0)

Reference on swapping array axes: [Swapping the dimensions of a numpy array](https://stackoverflow.com/questions/23943379/swapping-the-dimensions-of-a-numpy-array)

In [5]:
# torch.save(d.cropped, './data/tensor/cropped.pt')
# torch.save(d.pixelated, './data/tensor/pixelated.pt')
# torch.save(d.blurred, './data/tensor/blurred.pt')
# each save takes about 20 seconds

In [6]:
# d2 = lfw() # about 2.5 minutes without the cached tensors, just under 2 minutes with them. Might as well save the 5.56 GB of storage...

In [7]:
class lfw_b(Dataset):
    """In-memory train or test split of LFW yielding (cropped, blurred) image pairs.

    Files ``{j}.jpg`` are read from ``<base_path>cropped/`` and ``<base_path>blurred/``. With
    ``n`` the number of files in ``cropped/``, indices ``0 .. int(training_frac * n) - 1`` form
    the training split and the remaining indices the test split; ``train`` selects which one is
    loaded. Tensors have shape ``(n_files, 3, photo_length, photo_length)``, are scaled to
    [0, 1] and then optionally centred / standardised with the statistics of the loaded split.

    Args:
        train: load the training split if true, otherwise the test split.
        training_frac: fraction of the file indices assigned to the training split.
        base_path: folder holding the image sub-folders; must end with a path separator.
        photo_length: side length of the square images.
        subtract_mean: centre each tensor on its own global mean (after scaling to [0, 1]).
        divide_std: divide each tensor by its own global std (after scaling to [0, 1]).
        device: ``'cuda'`` moves the tensors to the GPU; any other value keeps them on the CPU.
    """

    def __init__(self, train: bool, training_frac = 0.8, base_path: str = './data/', photo_length: int = 112,
                 subtract_mean: bool = False, divide_std: bool = False, device: str = 'cpu'):
        self.cropped_file_path = base_path + 'cropped/'
        self.blurred_file_path = base_path + 'blurred/'
        self.photo_length = photo_length
        self.train = train

        n_total = len(os.listdir(self.cropped_file_path))  # number of files across both splits
        train_index = int(training_frac * n_total)
        faces_indices = range(train_index) if self.train else range(train_index, n_total)
        self.n_files = len(faces_indices)  # size of the selected split only

        shape = (self.n_files, 3, self.photo_length, self.photo_length)  # channels first, as torchvision returns
        self.cropped = torch.empty(shape)
        self.blurred = torch.empty(shape)

        for i, j in enumerate(faces_indices):  # i: position in this split, j: global file index
            file_name = f'{j}.jpg'
            self.cropped[i] = torchvision.io.read_image(self.cropped_file_path + file_name)
            self.blurred[i] = torchvision.io.read_image(self.blurred_file_path + file_name)

        # Scale raw pixel values to [0, 1] *before* the optional standardisation: dividing by 255
        # afterwards would shrink a unit-std tensor down to std 1/255.
        for images in (self.cropped, self.blurred):
            images /= 255
            if subtract_mean:
                images -= images.mean()
            if divide_std:
                images /= images.std()

        if device == 'cuda':
            self.cropped, self.blurred = self.cropped.to('cuda'), self.blurred.to('cuda')
            self.device = 'cuda'
        else:
            self.device = 'cpu'

    def __getitem__(self, index):
        return self.cropped[index], self.blurred[index]

    def __len__(self):
        return self.n_files

In [8]:
#blur = lfw_b() # CUDA + blurred only: a little over 2 minutes?

In [9]:
#blur.blurred.shape # 13233, 3, 112, 112

In [10]:
# Training split of the blurred/cropped pairs, kept on the CPU (lfw_b's default device).
training_set = lfw_b(train=True)

In [11]:
# (n_images, channels, height, width) of the loaded training split
training_set.cropped.size()

torch.Size([10586, 3, 112, 112])

In [12]:
# number of samples fit() will train on after its own 80/20 train/validation split
int(0.8 * training_set.n_files)

8468

In [13]:
# test_set = lfw_b(train = False)

In [14]:
# test_set.cropped.shape

In [15]:
class lfw_tensor(Dataset):
    """Thin Dataset over already-loaded tensors, yielding (cropped, pixelated, blurred) triples."""

    def __init__(self, cropped, pixelated, blurred):
        self.cropped = cropped
        self.pixelated = pixelated
        self.blurred = blurred

    def __len__(self):
        # the three tensors are expected to share their first (sample) dimension
        return self.cropped.shape[0]

    def __getitem__(self, index):
        return tuple(images[index] for images in (self.cropped, self.pixelated, self.blurred))


In [16]:
class lfw_tensor_b(Dataset):
    """Thin Dataset over already-loaded tensors, yielding (cropped, blurred) pairs."""

    def __init__(self, cropped, blurred):
        self.cropped = cropped
        self.blurred = blurred

    def __len__(self):
        # both tensors are expected to share their first (sample) dimension
        return self.cropped.shape[0]

    def __getitem__(self, index):
        return tuple(images[index] for images in (self.cropped, self.blurred))

In [17]:
# Helper that builds train/validation DataLoaders from an `lfw` dataset; batch_size is an
# argument so it can be called inside a grid search.
def create_train_val_dataloaders(dataset: lfw, batch_size: int, training_frac = 0.8, num_workers: int = 0, seed: int = 1234):
    """Randomly split ``dataset`` into train/validation DataLoaders of (cropped, pixelated, blurred).

    ``int(training_frac * len(dataset))`` randomly chosen samples go to training and the rest to
    validation. The split is drawn from the global NumPy RNG seeded with ``seed``; the torch and
    CUDA RNGs are seeded too so the training loader's shuffling is reproducible. Only the
    training loader shuffles.
    """
    np.random.seed(seed)
    n_samples = len(dataset)
    n_train = int(training_frac * n_samples)
    train_idx = np.random.choice(n_samples, size = n_train, replace = False)
    val_idx = np.setdiff1d(np.arange(n_samples), train_idx)

    def _subset(idx):
        # materialise the selected samples of all three image sets
        return lfw_tensor(dataset.cropped[idx], dataset.pixelated[idx], dataset.blurred[idx])

    torch.random.manual_seed(seed)
    torch.cuda.random.manual_seed(seed)
    loader_kwargs = dict(batch_size = batch_size, num_workers = num_workers)
    return (DataLoader(_subset(train_idx), shuffle = True, **loader_kwargs),
            DataLoader(_subset(val_idx), shuffle = False, **loader_kwargs))

In [18]:
def create_train_val_dataloaders_b(dataset: 'lfw_b', batch_size: int, training_frac = 0.8, num_workers: int = 0, seed: int = 1234):
    """Randomly split ``dataset`` into train/validation DataLoaders of (cropped, blurred) pairs.

    ``dataset`` only needs ``__len__`` and indexable ``cropped`` / ``blurred`` tensors (e.g. an
    ``lfw_b`` or ``lfw`` instance). ``int(training_frac * len(dataset))`` randomly chosen samples
    go to training and the rest to validation. The split is drawn from the global NumPy RNG
    seeded with ``seed``; the torch and CUDA RNGs are seeded too so the training loader's
    shuffling is reproducible. Only the training loader shuffles.

    Returns:
        ``(train_dataloader, val_dataloader)``.
    """
    np.random.seed(seed)
    n_samples = len(dataset)
    train_indices = np.random.choice(n_samples, size = int(training_frac * n_samples), replace = False)
    val_indices   = np.setdiff1d(np.arange(n_samples), train_indices)

    train_partial_dataset = lfw_tensor_b(dataset.cropped[train_indices], dataset.blurred[train_indices])
    val_dataset = lfw_tensor_b(dataset.cropped[val_indices], dataset.blurred[val_indices])

    torch.random.manual_seed(seed)
    torch.cuda.random.manual_seed(seed)
    train_dataloader = DataLoader(train_partial_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers)
    val_dataloader   = DataLoader(val_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers)
    return train_dataloader, val_dataloader

In [19]:
# split train and test

In [73]:
class VarAutoEncoder(nn.Module):
    """Convolutional variational autoencoder for 3 x 112 x 112 RGB images.

    The encoder maps an image to a 512-dim feature vector (32 x 4 x 4). Two small MLP heads
    predict the mean and log-variance of a diagonal Gaussian in a ``latent_space_dim``-dim
    latent space. The decoder maps a latent sample back to a 3 x 112 x 112 image in [0, 1].

    Args:
        latent_space_dim: size of the latent space.
        activation_function: name of a ``torch.nn`` activation class, e.g. ``'ReLU'``.
        optimizer: name of a ``torch.optim`` optimizer; ``'SGD'`` gets Nesterov momentum 0.9.
        initial_lr: learning rate given to the optimizer.
        batch_size: batch size used by ``fit`` and ``test_accuracy``.
        dropout: dropout probability; ``None`` disables dropout.
        l1_reg_strength: weight of an L1 penalty on all parameters; ``None`` disables it.
        l2_reg_strength: weight of an L2 penalty on all parameters; ``None`` disables it.
            When both strengths are given, both penalties are added.
        loss_fn: reconstruction loss (summed over the batch by default).
        device: stored as ``self.device``; the constructor does NOT move the model, use ``.to()``.
    """

    def __init__(self, latent_space_dim: int, activation_function: str, optimizer: str,
                initial_lr: float, batch_size: int, dropout: float = None, l1_reg_strength: float = None, l2_reg_strength: float = None,
                loss_fn = nn.MSELoss(reduction = 'sum'), device = 'cuda'):
        super().__init__()
        self.latent_space_dim = latent_space_dim
        self.activation = getattr(nn, activation_function)()
        self.device = device
        self.dropout = dropout

        def drop(channel_wise: bool = False):
            # A dropout layer (channel-wise Dropout2d on request), or a no-op when dropout is None.
            if dropout is None:
                return nn.Identity()
            return nn.Dropout2d(dropout) if channel_wise else nn.Dropout(dropout)

        # Spatial size for a 112 x 112 input: 112 -> 37 -> 19 -> 9 -> 4.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size = 4, stride = 3, padding = 1),   # out = (8, 37, 37)
            self.activation,
            drop(channel_wise = True),
            nn.Conv2d(8, 16, kernel_size = 3, stride = 2, padding = 1),  # out = (16, 19, 19)
            self.activation,
            drop(),
            nn.Conv2d(16, 32, kernel_size = 3, stride = 2, padding = 0), # out = (32, 9, 9)
            self.activation,
            drop(),
            nn.Conv2d(32, 32, kernel_size = 3, stride = 2, padding = 0), # out = (32, 4, 4)
            self.activation,
            drop(),
            nn.Flatten(start_dim = 1)                                    # out = 512 = 32 * 4 * 4
        )

        # Mirror of the encoder: latent -> 512 -> (32, 4, 4) -> 9 -> 19 -> 37 -> 112.
        self.decoder = nn.Sequential(
            nn.Linear(self.latent_space_dim, 112),
            self.activation,
            drop(),
            nn.Linear(112, 512),
            self.activation,
            drop(),
            nn.Unflatten(dim = 1, unflattened_size = (32, 4, 4)),
            nn.ConvTranspose2d(32, 32, kernel_size = 3, stride = 2, output_padding = 0),             # out = (32, 9, 9)
            self.activation,
            drop(),
            nn.ConvTranspose2d(32, 16, kernel_size = 3, stride = 2, output_padding = 0),             # out = (16, 19, 19)
            self.activation,
            drop(),
            nn.ConvTranspose2d(16, 8, kernel_size = 3, stride = 2, padding = 1, output_padding = 0), # out = (8, 37, 37)
            self.activation,
            drop(channel_wise = True),
            nn.ConvTranspose2d(8, 3, kernel_size = 4, stride = 3, padding = 1, output_padding = 2),  # out = (3, 112, 112)
            nn.Sigmoid() # -> [0, 1] output
        )

        # Predicts the means of a multivariate normal in latent space.
        self.avg = nn.Sequential(
            nn.Linear(512, 112),
            self.activation,
            drop(),
            nn.Linear(112, self.latent_space_dim)
        )

        # Predicts the log-variances of an uncorrelated MVN (log of the covariance diagonal).
        self.log_var = nn.Sequential(
            nn.Linear(512, 112),
            self.activation,
            drop(),
            nn.Linear(112, self.latent_space_dim)
        )

        if optimizer == 'SGD': # SGD gets Nesterov momentum with the fairly standard value 0.9
            self.optimizer = torch.optim.SGD(params = self.parameters(), momentum = 0.9, nesterov = True, lr = initial_lr)
        else:
            self.optimizer = getattr(torch.optim, optimizer)(params = self.parameters(), lr = initial_lr)

        self.batch_size = batch_size
        self.loss_fn = loss_fn
        self.l1_reg_strength = l1_reg_strength
        self.l2_reg_strength = l2_reg_strength
        # Kept as a callable attribute so existing ``self.reg()`` callers keep working.
        self.reg = self._regularization

    def _regularization(self):
        """Weight penalty added to the training loss: a 0-dim tensor on the parameters' device."""
        params = list(self.parameters())
        penalty = torch.zeros((), device = params[0].device)
        if self.l1_reg_strength is not None:
            penalty = penalty + self.l1_reg_strength * sum(w.abs().sum() for w in params)
        if self.l2_reg_strength is not None:
            penalty = penalty + self.l2_reg_strength * sum((w ** 2).sum() for w in params)
        return penalty

    @staticmethod
    def _model_input(batch):
        """Images fed to the model: the last element of a tuple/list batch (e.g. blurred), else the batch."""
        return batch[-1] if isinstance(batch, (tuple, list)) else batch

    def sample_in_latent_space(self, mu, log_var): # sampling from MVN in latent space with provided mean and variances
        sigma = torch.exp(0.5 * log_var) # log_var is log(variance) = log(sigma**2) = 2 * log(sigma)
        return mu + torch.randn_like(mu) * sigma
        # Var(aX) = a**2 Var(X), so the standard normal is scaled by sigma (the square root of the variance)

    def forward(self, x):
        internal_repr  = self.encoder(x) # conv. part --> produces the internal representation
        pred_means     = self.avg(internal_repr) # linear part 1 --> predicts means
        pred_log_var   = self.log_var(internal_repr) # linear part 2 --> predicts log-variances
        sample         = self.sample_in_latent_space(mu = pred_means, log_var = pred_log_var) # sample in latent space
        decoded_sample = self.decoder(sample) # decoder --> produces final output

        return decoded_sample, pred_means, pred_log_var # means and log-vars are needed for the KL term of the loss

    def train_single_epoch(self, train_dataloader, verbose_single_epoch: bool = True):
        """Run one optimisation pass over ``train_dataloader``; return the mean per-batch loss as a float.

        Per batch, the loss is (reconstruction loss + weight penalty + KL divergence to N(0, I))
        divided by the batch size. Returns NaN for an empty loader.
        """
        self.train() # enable dropout etc.
        total, n_batches = 0., 0
        loading = tqdm(train_dataloader) if verbose_single_epoch else train_dataloader
        for batch in loading:
            # NOTE(review): the reconstruction target is the input itself (e.g. blurred -> blurred),
            # while test_accuracy scores predictions against the clean ``cropped`` images.
            x_batch = self._model_input(batch)
            output, mu, log_var = self(x_batch)
            kl_div = -0.5 * torch.sum(1. + log_var - mu ** 2 - torch.exp(log_var))
            loss = (self.loss_fn(output, x_batch) + self.reg() + kl_div) / len(x_batch)
            self.optimizer.zero_grad() # reset gradients
            loss.backward() # backpropagation
            self.optimizer.step() # update weights
            total += loss.item()
            n_batches += 1
        return total / n_batches if n_batches else float('nan')

    @torch.no_grad()
    def val_single_epoch(self, val_dataloader):
        """Mean per-batch MSE between reconstructions and their inputs; NaN for an empty loader.

        Unlike the training loss this excludes the KL and weight-penalty terms and averages over
        pixels, so the two numbers are on different scales.
        """
        self.eval() # disable dropout etc.
        total, n_batches = 0., 0
        for batch in val_dataloader:
            x_batch = self._model_input(batch)
            output, _, _ = self(x_batch)
            total += nn.MSELoss()(output, x_batch).item() # reconstructed output vs. original input
            n_batches += 1
        return total / n_batches if n_batches else float('nan')

    def fit(self, training_dataset, max_n_iter: int = 25, min_n_iter: int = 3, patience: int = 4, tol: float = 0.0001, seed: int = 1234, verbose: bool = True, num_workers: int = 0):
        """Train with early stopping on a random train/validation split of ``training_dataset``.

        The split is made by ``create_train_val_dataloaders_b`` with ``seed``, when one is given.
        Per-epoch errors are stored in ``training_error_history`` / ``val_error_history``
        (zero-padded to ``max_n_iter``). Once ``epoch > min_n_iter``, an epoch whose validation
        error exceeds ``best_val_error + tol`` increments a patience counter; any other epoch
        resets it and becomes the new best. Training stops when the counter exceeds ``patience``.
        """
        split_kwargs = {}
        if seed is not None: # seed everything, including the train/validation split, for reproducibility
            torch.random.manual_seed(seed)
            split_kwargs['seed'] = seed

        training_dataloader, validation_dataloader = create_train_val_dataloaders_b(
            training_dataset, batch_size = self.batch_size, num_workers = num_workers, **split_kwargs)

        self.training_error_history, self.val_error_history = np.zeros(max_n_iter), np.zeros(max_n_iter)
        self.best_val_error = np.inf
        patience_counter = 0
        n_executed_epochs = 0

        for epoch in trange(max_n_iter):
            if verbose:
                print(f'Training {epoch=}')
            tr_err = self.training_error_history[epoch] = self.train_single_epoch(
                train_dataloader = training_dataloader, verbose_single_epoch = verbose)
            val_err = self.val_error_history[epoch] = self.val_single_epoch(val_dataloader = validation_dataloader)
            n_executed_epochs = epoch + 1 # 1-based count of completed epochs
            if verbose:
                print(f'{epoch=}: {tr_err=}, {val_err=}')

            if epoch > min_n_iter:
                if val_err > self.best_val_error + tol:
                    patience_counter += 1
                else:
                    self.best_val_error = val_err
                    patience_counter = 0
                if patience_counter > patience:
                    break

        # Executed epochs minus the trailing run of non-improving ones (1-based epoch number).
        self.best_val_error_epoch = n_executed_epochs - patience_counter

    @torch.no_grad()
    def predict(self, sample):
        """Reconstruct ``sample`` in eval mode; a single (C, H, W) image is returned as a batch of one.

        The latent vector is still sampled randomly, so repeated calls give different outputs.
        """
        self.eval()
        if sample.dim() == 3: # single image -> add the batch dimension
            sample = sample.unsqueeze(0)
        return self(sample)[0] # forward also returns means and log-vars; keep only the reconstruction

    def test_accuracy(self, test_dataset):
        """MSE between reconstructions of the corrupted images and the clean ``cropped`` images.

        The input is ``test_dataset.blurred`` when present, otherwise ``test_dataset.pixelated``.
        Prediction runs in chunks of ``self.batch_size``. Despite the name, lower is better.
        """
        original = test_dataset.cropped
        corrupted = getattr(test_dataset, 'blurred', None)
        if corrupted is None:
            corrupted = test_dataset.pixelated
        squared_error = 0.
        for start in range(0, len(original), self.batch_size):
            stop = start + self.batch_size
            reconstructed = self.predict(corrupted[start:stop])
            squared_error += float(((reconstructed - original[start:stop]) ** 2).sum())
        return squared_error / original.numel()

    @staticmethod
    def _to_hwc(image):
        """Convert a (C, H, W) tensor into an (H, W, C) numpy array for ``imshow``."""
        return image.detach().cpu().numpy().transpose((1, 2, 0))

    def _reconstruct_hwc(self, sample):
        """Reconstruct one (C, H, W) sample and return it as an (H, W, C) numpy array."""
        return self._to_hwc(self.predict(sample).squeeze(0))

    @staticmethod
    def _sample_image(dataset, index):
        """Image used for plotting: ``dataset.data[index]`` if present, else the last item field (e.g. blurred)."""
        if hasattr(dataset, 'data'):
            return dataset.data[index]
        item = dataset[index]
        return item[-1] if isinstance(item, (tuple, list)) else item

    def plot_original_vs_reconstructed_sample(self, sample, figsize = (7,7), return_array: bool = False):
        """Plot ``sample`` next to its reconstruction; with ``return_array`` only return the reconstruction."""
        img_rec = self._reconstruct_hwc(sample)
        if return_array: # no figure is created in this mode
            return img_rec

        fig, ax = plt.subplots(nrows = 1, ncols = 2, figsize = figsize)
        for axis, image, title in ((ax[0], self._to_hwc(sample), 'Orig.'), (ax[1], img_rec, 'Rec.')):
            axis.imshow(image, cmap = 'gray')
            axis.set_title(title)
            axis.axis('off')
        return fig, ax

    def plot_reconstructed_samples(self, dataset, nrows: int = 3, ncols: int = 3, figsize = (42, 21)):
        """Grid of the first ``nrows * ncols`` samples, each shown next to its reconstruction.

        Returns ``(fig, ax)``, with ``ax`` always a 2-D array of shape ``(nrows, 2 * ncols)``.
        """
        fig, ax = plt.subplots(nrows = nrows, ncols = 2 * ncols, figsize = figsize, squeeze = False)
        for idx in range(nrows * ncols):
            row, col = divmod(idx, ncols)
            img = self._sample_image(dataset, idx)
            img_rec = self._reconstruct_hwc(img)
            panels = ((ax[row, 2 * col], self._to_hwc(img), 'Orig.'), (ax[row, 2 * col + 1], img_rec, 'Rec.'))
            for axis, image, title in panels:
                axis.imshow(image, cmap = 'gray')
                axis.set_title(title)
                axis.axis('off')
        return fig, ax

In [74]:
# training_set.cropped = training_set.cropped.to('cpu')
# training_set.blurred = training_set.blurred.to('cpu')

In [75]:
# The model is not moved to the GPU and training_set lives on the CPU, so keep the model's
# `device` (used for its regularisation term) on the CPU too instead of the 'cuda' default.
model = VarAutoEncoder(latent_space_dim=64, activation_function='ReLU', optimizer='Adam',
                       initial_lr=1e-3, batch_size=128, device='cpu')

In [60]:
# shape of a single (channels-first) image
training_set.cropped[0].size()

torch.Size([3, 112, 112])

In [61]:
# a batch holding only the first training image: (1, 3, 112, 112)
a = training_set.cropped[0:1]

In [62]:
# the first element returned by forward() is the reconstructed batch
model(a)[0].__class__

torch.Tensor

In [65]:
# dummy decoder feature map; values are uninitialised, only its shape matters below
z = torch.empty(1, 32, 4, 4)

In [72]:
# Sanity check of the decoder's transposed-convolution stack: (32, 4, 4) must map to (3, 112, 112).
z1 = nn.ConvTranspose2d(32, 32, kernel_size = 3, stride = 2, output_padding = 0)(z)
print(z1.shape)
z2 = nn.ConvTranspose2d(32, 16, kernel_size = 3, stride = 2, output_padding = 0)(z1)
print(z2.shape)
z3 = nn.ConvTranspose2d(16, 8, kernel_size = 3, stride = 2, padding = 1, output_padding = 0)(z2)
print(z3.shape)
z4 = nn.ConvTranspose2d(8, 3, kernel_size = 4, stride = 3, padding = 1, output_padding = 2)(z3)
print(z4.shape)
assert (z1.shape, z2.shape, z3.shape, z4.shape) == (
    (1, 32, 9, 9), (1, 16, 19, 19), (1, 8, 37, 37), (1, 3, 112, 112)), 'decoder no longer reconstructs 112x112 images'

torch.Size([1, 32, 9, 9])
torch.Size([1, 16, 19, 19])
torch.Size([1, 8, 37, 37])
torch.Size([1, 3, 112, 112])


In [None]:
# Scratch cell, not executed in this saved run (In [None]): an alternative kernel-3 / stride-2
# tail for the decoder. NOTE(review): the trailing commas make the first two lines throw-away
# 1-tuples and nothing is assigned, so running it has no lasting effect. By the ConvTranspose2d
# size formula this stack would turn a 4x4 map into 35x35, not 112x112.
nn.ConvTranspose2d(32, 16, kernel_size = 3, stride = 2, output_padding = 0),
nn.ConvTranspose2d(16, 8, kernel_size = 3, stride = 2, padding = 1, output_padding = 0),
nn.ConvTranspose2d(8, 3, kernel_size = 3, stride = 2, padding = 0, output_padding = 0)

In [64]:
# flattened encoder features for the single-image batch
model.encoder(a).size()

torch.Size([1, 512])

In [63]:
# reconstructed batch shape. NOTE(review): the saved output below (75x75) comes from In [63],
# which ran before the model class was redefined in In [73]; the current decoder yields 112x112.
model(a)[0].size()

torch.Size([1, 3, 75, 75])

In [76]:
# train with the default early-stopping settings on an internal train/validation split
model.fit(training_dataset=training_set)

  0%|          | 0/25 [00:00<?, ?it/s]

Training epoch=0


100%|██████████| 67/67 [00:21<00:00,  3.10it/s]
  4%|▍         | 1/25 [00:23<09:26, 23.60s/it]

epoch=0: tr_err=array([1624.54515719]), val_err=0.03584180465515922
Training epoch=1


100%|██████████| 67/67 [00:18<00:00,  3.64it/s]
  8%|▊         | 2/25 [00:43<08:19, 21.70s/it]

epoch=1: tr_err=array([1231.99766686]), val_err=0.028168122119763318
Training epoch=2


100%|██████████| 67/67 [00:18<00:00,  3.63it/s]
 12%|█▏        | 3/25 [01:04<07:44, 21.11s/it]

epoch=2: tr_err=array([860.44919476]), val_err=0.01658227782258216
Training epoch=3


100%|██████████| 67/67 [00:18<00:00,  3.66it/s]
 16%|█▌        | 4/25 [01:24<07:16, 20.77s/it]

epoch=3: tr_err=array([585.69971928]), val_err=0.014921854743186165
Training epoch=4


100%|██████████| 67/67 [00:18<00:00,  3.63it/s]
 20%|██        | 5/25 [01:45<06:52, 20.63s/it]

epoch=4: tr_err=array([517.06675769]), val_err=0.011850555372588775
Training epoch=5


100%|██████████| 67/67 [00:18<00:00,  3.58it/s]
 24%|██▍       | 6/25 [02:05<06:32, 20.65s/it]

epoch=5: tr_err=array([466.2190589]), val_err=0.01064763470169376
Training epoch=6


100%|██████████| 67/67 [00:18<00:00,  3.65it/s]
 28%|██▊       | 7/25 [02:26<06:10, 20.56s/it]

epoch=6: tr_err=array([440.99903326]), val_err=0.010193193933981307
Training epoch=7


100%|██████████| 67/67 [00:18<00:00,  3.59it/s]
 32%|███▏      | 8/25 [02:46<05:50, 20.62s/it]

epoch=7: tr_err=array([425.03187214]), val_err=0.009747018106281757
Training epoch=8


100%|██████████| 67/67 [00:18<00:00,  3.59it/s]
 36%|███▌      | 9/25 [03:07<05:30, 20.63s/it]

epoch=8: tr_err=array([407.0350564]), val_err=0.009530601749087082
Training epoch=9


100%|██████████| 67/67 [00:18<00:00,  3.56it/s]
 40%|████      | 10/25 [03:28<05:10, 20.69s/it]

epoch=9: tr_err=array([391.99341701]), val_err=0.009074156556059332
Training epoch=10


100%|██████████| 67/67 [00:18<00:00,  3.54it/s]
 44%|████▍     | 11/25 [03:49<04:50, 20.76s/it]

epoch=10: tr_err=array([376.18165498]), val_err=0.00836055654594127
Training epoch=11


100%|██████████| 67/67 [00:19<00:00,  3.50it/s]
 48%|████▊     | 12/25 [04:10<04:31, 20.87s/it]

epoch=11: tr_err=array([358.99870939]), val_err=0.008214823600343046
Training epoch=12


100%|██████████| 67/67 [00:18<00:00,  3.57it/s]
 52%|█████▏    | 13/25 [04:31<04:10, 20.84s/it]

epoch=12: tr_err=array([352.69969041]), val_err=0.007997550672906287
Training epoch=13


100%|██████████| 67/67 [00:19<00:00,  3.52it/s]
 56%|█████▌    | 14/25 [04:52<03:49, 20.89s/it]

epoch=13: tr_err=array([342.11247726]), val_err=0.00813441605800215
Training epoch=14


100%|██████████| 67/67 [00:18<00:00,  3.60it/s]
 60%|██████    | 15/25 [05:12<03:27, 20.80s/it]

epoch=14: tr_err=array([336.44464479]), val_err=0.007610452005310971
Training epoch=15


100%|██████████| 67/67 [00:18<00:00,  3.56it/s]
 64%|██████▍   | 16/25 [05:33<03:07, 20.82s/it]

epoch=15: tr_err=array([330.00454987]), val_err=0.0075379093451535
Training epoch=16


100%|██████████| 67/67 [00:18<00:00,  3.59it/s]
 68%|██████▊   | 17/25 [05:54<02:46, 20.79s/it]

epoch=16: tr_err=array([325.37611954]), val_err=0.007394536672269597
Training epoch=17


100%|██████████| 67/67 [00:19<00:00,  3.50it/s]
 72%|███████▏  | 18/25 [06:15<02:26, 20.88s/it]

epoch=17: tr_err=array([320.58769783]), val_err=0.007143169093657942
Training epoch=18


100%|██████████| 67/67 [00:19<00:00,  3.50it/s]
 76%|███████▌  | 19/25 [06:36<02:05, 20.98s/it]

epoch=18: tr_err=array([312.22545547]), val_err=0.007143096833982889
Training epoch=19


100%|██████████| 67/67 [00:19<00:00,  3.44it/s]
 80%|████████  | 20/25 [06:58<01:45, 21.15s/it]

epoch=19: tr_err=array([310.88268111]), val_err=0.006940418580437408
Training epoch=20


100%|██████████| 67/67 [00:19<00:00,  3.38it/s]
 84%|████████▍ | 21/25 [07:19<01:25, 21.35s/it]

epoch=20: tr_err=array([303.15243232]), val_err=0.0069327530560686305
Training epoch=21


100%|██████████| 67/67 [00:20<00:00,  3.29it/s]
 88%|████████▊ | 22/25 [07:42<01:04, 21.65s/it]

epoch=21: tr_err=array([296.93699409]), val_err=0.006886642481036046
Training epoch=22


100%|██████████| 67/67 [00:19<00:00,  3.48it/s]
 92%|█████████▏| 23/25 [08:03<00:43, 21.53s/it]

epoch=22: tr_err=array([295.02754308]), val_err=0.0066125574795638815
Training epoch=23


100%|██████████| 67/67 [00:19<00:00,  3.47it/s]
 96%|█████████▌| 24/25 [08:24<00:21, 21.48s/it]

epoch=23: tr_err=array([290.21534734]), val_err=0.006385253687553546
Training epoch=24


100%|██████████| 67/67 [00:19<00:00,  3.46it/s]
100%|██████████| 25/25 [08:46<00:00, 21.06s/it]

epoch=24: tr_err=array([285.81497525]), val_err=0.006389781312250039



