<a href="https://colab.research.google.com/github/beamaia/cvae-clustering/blob/main/histvae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install torch \
torchvision \
matplotlib \
numpy \
scikit-learn \
opencv-python \
openslide-python \
pandas \
Pillow \
scipy \
wandb

**Based on [HistoVAE](https://github.com/beamaia/HistoVAE/blob/main/src/models.py)**

In [2]:
#torch==2.2.1
#torchvision==0.17.1
#matplotlib==3.8.3
#numpy==1.26.4
#opencv-python==4.9.0.80
#pandas==2.2.0
#Pillow==10.2.0
#scipy==1.12.0
#wandb==0.16.3

In [3]:
import torch
from torch import nn
from math import log
from torch.nn import functional as F
import torchvision.transforms as T
import numpy as np

In [4]:
class View(nn.Module):
    def __init__(self, shape):
        super(View, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(*self.shape)

In [306]:
class VCAE(nn.Module):
    def __init__(self, inner_dim:int = 2048, dropout_rate:float = 0., device='cpu'):
        super(VCAE, self).__init__()
        self.inner_dim = inner_dim
        self.dropout_rate = dropout_rate
        self.epsilon_std = 1.0

        self.device = device

        self.encoder = self._define_encoder().to(device)
        self.decoder = self._define_decoder().to(device)

        self.z_mean = nn.Linear(128* 12* 12, self.inner_dim).to(device)
        self.z_log_var = nn.Linear(128* 12* 12, self.inner_dim).to(device)


    def _define_encoder(self, x=None):
         return nn.Sequential(nn.Conv2d(3, 16, kernel_size=3),
                            nn.LeakyReLU(),
                            nn.AvgPool2d(kernel_size=2, stride=2),

                            nn.Conv2d(16, 32, kernel_size=3),
                            nn.LeakyReLU(),
                            nn.AvgPool2d(kernel_size=2, stride=2),

                            nn.Conv2d(32, 64, kernel_size=3),
                            nn.LeakyReLU(),
                            nn.AvgPool2d(kernel_size=2, stride=2),

                            nn.Conv2d(64, 128, kernel_size=3),
                            nn.LeakyReLU(),
                            nn.AvgPool2d(kernel_size=2, stride=2),

                            nn.Flatten(),
                            nn.Dropout(self.dropout_rate)
                            )


    def _define_decoder(self):
        return nn.Sequential(nn.Dropout(self.dropout_rate),
                            nn.Linear(self.inner_dim, 128*12*12),
                            View((-1, 128, 12, 12)),

                            nn.Upsample(size=(24, 24)),
                            nn.LeakyReLU(),
                            nn.ConvTranspose2d(128, 64, kernel_size=3),

                            nn.Upsample(size=(52, 52)),
                            nn.LeakyReLU(),
                            nn.ConvTranspose2d(64, 32, kernel_size=3),


                            nn.Upsample(size=(109, 109)),
                            nn.LeakyReLU(),
                            nn.ConvTranspose2d(32, 16, kernel_size=3),

                            nn.Upsample(size=(223, 223)),
                            nn.LeakyReLU(),
                            nn.ConvTranspose2d(16, 3, kernel_size=3),
                            )

    def _resize_input(self, image):
        if image.shape[2] == 225 and image.shape[3] == 225:
            return image
        return T.Resize(size=225)(image).to(self.device)


    def encode(self, image):
        image = self._resize_input(image)
        x = self.encoder(image)
        z_mean = self.z_mean(x)
        z_log_var = self.z_log_var(x)
        return z_mean, z_log_var

    def decode(self, features):
        decoded = self.decoder(features)
        return decoded


    def sampling(self, z_mean, z_log_var):
        epsilon = torch.normal(size=(z_mean.shape[0], self.inner_dim), mean=0, std=self.epsilon_std).to(self.device)
        return z_mean + torch.exp(z_log_var) * epsilon


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z_mean, z_log_var = self.encode(x)
        z = self.sampling(z_mean, z_log_var)
        decoded = self.decoder(z)

        return z_mean, z_log_var, z, decoded



In [307]:
image = torch.randint(0, 255, (3, 3, 225, 225), dtype=torch.float32)

In [308]:
vae = VCAE(dropout_rate=0.5)

In [309]:
features = vae.encode(image)

---

In [310]:
import torchvision
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset, TensorDataset

In [311]:
import pathlib as pl
from glob import glob

In [312]:
from PIL import Image

In [313]:
class PNDBDataset(Dataset):
    """P-NDB-Dataset"""

    def __init__(self, folder_path, transform=None):
        """
        Arguments:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.path = folder_path
        self.classes_to_idx = {
            "carcinoma": 0,
            "no_dysplasia": 1,
            "with_dysplasia": 2
        }

        self.img_transform = T.Compose([T.PILToTensor()])
        self.transform = transform
        self._organize_image()


    def _load_image(self, path):
        im = Image.open(path)
        im = self.img_transform(im)
        im = T.Resize(size=225)(im)
        return im

    def _organize_image(self):
        carcinoma_images_path = pl.Path(f"{self.path}/carcinoma")
        no_dysplasia_images_path = pl.Path(f"{self.path}/no_dysplasia")
        dysplasia_images_path = pl.Path(f"{self.path}/dysplasia")

        carcinoma_images = glob(f'{str(carcinoma_images_path)}/*.png')
        no_dysplasia_images = glob(f'{str(no_dysplasia_images_path)}/*.png')
        dysplasia_images = glob(f'{str(dysplasia_images_path)}/*.png')

        carcinoma_image_list = [self._load_image(path) for path in carcinoma_images]
        no_dysplasia_image_list = [self._load_image(path) for path in no_dysplasia_images]
        dysplasia_image_list = [self._load_image(path) for path in dysplasia_images]

        carcinoma_targets = torch.zeros(len(carcinoma_images))
        no_dysplasia_targets = torch.ones(len(no_dysplasia_images))
        dysplasia_targets = torch.ones(len(dysplasia_images)) * 2

        images_list = []
        images_list.extend(carcinoma_image_list)
        images_list.extend(no_dysplasia_image_list)
        images_list.extend(dysplasia_image_list)

        targets_list = []
        targets_list.extend(carcinoma_targets)
        targets_list.extend(no_dysplasia_targets)
        targets_list.extend(dysplasia_targets)

        self.data = torch.tensor(np.array(images_list)) / 1.
        self.target = torch.tensor(targets_list)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image = self.data[idx]
        target = self.target[idx]

        if self.transform:
            image = self.transform(image)

        return image, target

In [314]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [315]:
!pip install patool
import patoolib



In [316]:
#patoolib.extract_archive('/content/drive/MyDrive/dataset/data_train_test.zip')

In [317]:
learning_rate = 1e-3
batch_size = 32
epochs = 200

In [318]:
train_dataset = PNDBDataset("/content/data/train")
test_dataset = PNDBDataset("/content/data/test")



In [362]:
train_dataset = torchvision.datasets.CIFAR10(root="/content/data/", train=True, download = True, transform=T.Compose([T.ToTensor()]))
test_dataset = torchvision.datasets.CIFAR10(root="/content/data/", train=False, download = True, transform=T.Compose([T.ToTensor()]))

Files already downloaded and verified
Files already downloaded and verified


In [363]:
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [364]:
optimizer = torch.optim.SGD(vae.parameters(), lr=learning_rate)

In [365]:
loss_fn = F.mse_loss

In [366]:
!pip install wandb
import wandb



In [367]:
wandb.login()



True

In [368]:
alpha = 1
beta = 1

In [369]:
wandb.init(
        project="tcc",
        config={
            "epochs": epochs,
            "batch_size": batch_size,
            "lr": learning_rate,
            "optimizer": "SGD",
            "alpha": alpha,
            "beta": beta
            })

VBox(children=(Label(value='0.002 MB of 0.002 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

In [370]:
device = torch.device('cuda')

In [373]:
def _resize_input(image, device):
    if image.shape[2] == 225 and image.shape[3] == 225:
        return image
    return T.Resize(size=225)(image).to(device)

In [371]:
vae = VCAE(dropout_rate=0.1, device=device)

In [375]:
vae.train()
total_loss = []
n_steps_per_epoch = len(train_dataloader)
for epoch in range(epochs):
    epochs_loss = []
    print("Epoch:", epoch)
    for i, data in enumerate(train_dataloader):
        # Compute prediction and loss
        X, y = data
        X, y = X.to(device), y.to(device)
        X = _resize_input(X, device)
        print(X, y)
        mean, logvar, sampling, decoded = vae(X)

        kl = - 0.5 * torch.mean(1 + logvar - mean**2 - torch.exp(logvar), axis=1)
        # kl_mean = kl.mean()
        batchsize = X.shape[0]

        x_flat = X.flatten()
        decoded_flat = decoded.flatten()

        pixelwise = loss_fn(x_flat, decoded_flat, reduction='none')
        # pixelwise_mean = pixelwise.mean()


        loss = alpha * pixelwise.mean() + beta * kl.mean()
        loss = loss.mean()

        if not i % 10:
          print("pixelwise", pixelwise)
          print("kl", kl)
          print("loss", loss)
          print()

        optimizer.zero_grad()

        # Backpropagation
        loss.backward()
        optimizer.step()

        epochs_loss.append(loss.cpu().detach().numpy())
        metrics = {"train/train_loss": loss,
                    "train/epoch": (i + 1 + (n_steps_per_epoch * epoch)) / n_steps_per_epoch,
                    "train/kl_div": kl,
                    "train/mse_loss": pixelwise,
                    "train/epoch": epoch + 1}

        wandb.log(metrics)


Epoch: 0
tensor([[[[0.0941, 0.0941, 0.0941,  ..., 0.2235, 0.2235, 0.2235],
          [0.0941, 0.0941, 0.0941,  ..., 0.2235, 0.2235, 0.2235],
          [0.0941, 0.0941, 0.0941,  ..., 0.2235, 0.2235, 0.2235],
          ...,
          [0.0902, 0.0902, 0.0902,  ..., 0.3255, 0.3255, 0.3255],
          [0.0902, 0.0902, 0.0902,  ..., 0.3255, 0.3255, 0.3255],
          [0.0902, 0.0902, 0.0902,  ..., 0.3255, 0.3255, 0.3255]],

         [[0.0824, 0.0824, 0.0824,  ..., 0.2392, 0.2392, 0.2392],
          [0.0824, 0.0824, 0.0824,  ..., 0.2392, 0.2392, 0.2392],
          [0.0824, 0.0824, 0.0824,  ..., 0.2392, 0.2392, 0.2392],
          ...,
          [0.1882, 0.1882, 0.1882,  ..., 0.3529, 0.3529, 0.3529],
          [0.1882, 0.1882, 0.1882,  ..., 0.3529, 0.3529, 0.3529],
          [0.1882, 0.1882, 0.1882,  ..., 0.3529, 0.3529, 0.3529]],

         [[0.0824, 0.0824, 0.0824,  ..., 0.1725, 0.1725, 0.1725],
          [0.0824, 0.0824, 0.0824,  ..., 0.1725, 0.1725, 0.1725],
          [0.0824, 0.0824, 0.0824



[1;30;43mStreaming output truncated to the last 5000 lines.[0m

         [[0.6549, 0.6549, 0.6549,  ..., 0.8078, 0.8078, 0.8078],
          [0.6549, 0.6549, 0.6549,  ..., 0.8078, 0.8078, 0.8078],
          [0.6549, 0.6549, 0.6549,  ..., 0.8078, 0.8078, 0.8078],
          ...,
          [0.5294, 0.5294, 0.5294,  ..., 0.6667, 0.6667, 0.6667],
          [0.5294, 0.5294, 0.5294,  ..., 0.6667, 0.6667, 0.6667],
          [0.5294, 0.5294, 0.5294,  ..., 0.6667, 0.6667, 0.6667]],

         [[0.7020, 0.7020, 0.7020,  ..., 0.8431, 0.8431, 0.8431],
          [0.7020, 0.7020, 0.7020,  ..., 0.8431, 0.8431, 0.8431],
          [0.7020, 0.7020, 0.7020,  ..., 0.8431, 0.8431, 0.8431],
          ...,
          [0.4824, 0.4824, 0.4824,  ..., 0.6039, 0.6039, 0.6039],
          [0.4824, 0.4824, 0.4824,  ..., 0.6039, 0.6039, 0.6039],
          [0.4824, 0.4824, 0.4824,  ..., 0.6039, 0.6039, 0.6039]]],


        [[[0.4118, 0.4118, 0.4118,  ..., 0.1569, 0.1569, 0.1569],
          [0.4118, 0.4118, 0.4118,  ...,



tensor([[[[0.5294, 0.5294, 0.5294,  ..., 0.8118, 0.8118, 0.8118],
          [0.5294, 0.5294, 0.5294,  ..., 0.8118, 0.8118, 0.8118],
          [0.5294, 0.5294, 0.5294,  ..., 0.8118, 0.8118, 0.8118],
          ...,
          [0.5333, 0.5333, 0.5333,  ..., 0.5294, 0.5294, 0.5294],
          [0.5333, 0.5333, 0.5333,  ..., 0.5294, 0.5294, 0.5294],
          [0.5333, 0.5333, 0.5333,  ..., 0.5294, 0.5294, 0.5294]],

         [[0.6353, 0.6353, 0.6353,  ..., 0.9059, 0.9059, 0.9059],
          [0.6353, 0.6353, 0.6353,  ..., 0.9059, 0.9059, 0.9059],
          [0.6353, 0.6353, 0.6353,  ..., 0.9059, 0.9059, 0.9059],
          ...,
          [0.5608, 0.5608, 0.5608,  ..., 0.5490, 0.5490, 0.5490],
          [0.5608, 0.5608, 0.5608,  ..., 0.5490, 0.5490, 0.5490],
          [0.5608, 0.5608, 0.5608,  ..., 0.5490, 0.5490, 0.5490]],

         [[0.7059, 0.7059, 0.7059,  ..., 0.9373, 0.9373, 0.9373],
          [0.7059, 0.7059, 0.7059,  ..., 0.9373, 0.9373, 0.9373],
          [0.7059, 0.7059, 0.7059,  ..., 0



tensor([[[[0.9255, 0.9255, 0.9255,  ..., 0.9608, 0.9608, 0.9608],
          [0.9255, 0.9255, 0.9255,  ..., 0.9608, 0.9608, 0.9608],
          [0.9255, 0.9255, 0.9255,  ..., 0.9608, 0.9608, 0.9608],
          ...,
          [0.9686, 0.9686, 0.9686,  ..., 0.6353, 0.6353, 0.6353],
          [0.9686, 0.9686, 0.9686,  ..., 0.6353, 0.6353, 0.6353],
          [0.9686, 0.9686, 0.9686,  ..., 0.6353, 0.6353, 0.6353]],

         [[0.9137, 0.9137, 0.9137,  ..., 0.9529, 0.9529, 0.9529],
          [0.9137, 0.9137, 0.9137,  ..., 0.9529, 0.9529, 0.9529],
          [0.9137, 0.9137, 0.9137,  ..., 0.9529, 0.9529, 0.9529],
          ...,
          [0.9373, 0.9373, 0.9373,  ..., 0.6706, 0.6706, 0.6706],
          [0.9373, 0.9373, 0.9373,  ..., 0.6706, 0.6706, 0.6706],
          [0.9373, 0.9373, 0.9373,  ..., 0.6706, 0.6706, 0.6706]],

         [[0.9098, 0.9098, 0.9098,  ..., 0.9569, 0.9569, 0.9569],
          [0.9098, 0.9098, 0.9098,  ..., 0.9569, 0.9569, 0.9569],
          [0.9098, 0.9098, 0.9098,  ..., 0



tensor([[[[0.9922, 0.9922, 0.9922,  ..., 1.0000, 1.0000, 1.0000],
          [0.9922, 0.9922, 0.9922,  ..., 1.0000, 1.0000, 1.0000],
          [0.9922, 0.9922, 0.9922,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [0.9804, 0.9804, 0.9804,  ..., 0.9961, 0.9961, 0.9961],
          [0.9804, 0.9804, 0.9804,  ..., 0.9961, 0.9961, 0.9961],
          [0.9804, 0.9804, 0.9804,  ..., 0.9961, 0.9961, 0.9961]],

         [[0.8039, 0.8039, 0.8039,  ..., 0.8118, 0.8118, 0.8118],
          [0.8039, 0.8039, 0.8039,  ..., 0.8118, 0.8118, 0.8118],
          [0.8039, 0.8039, 0.8039,  ..., 0.8118, 0.8118, 0.8118],
          ...,
          [0.7961, 0.7961, 0.7961,  ..., 0.8078, 0.8078, 0.8078],
          [0.7961, 0.7961, 0.7961,  ..., 0.8078, 0.8078, 0.8078],
          [0.7961, 0.7961, 0.7961,  ..., 0.8078, 0.8078, 0.8078]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0039, 0.0039],
          [0.0000, 0.0000, 0.0000,  ..., 0.0039, 0.0039, 0.0039],
          [0.0000, 0.0000, 0.0000,  ..., 0



tensor([[[[0.3216, 0.3216, 0.3216,  ..., 0.6078, 0.6078, 0.6078],
          [0.3216, 0.3216, 0.3216,  ..., 0.6078, 0.6078, 0.6078],
          [0.3216, 0.3216, 0.3216,  ..., 0.6078, 0.6078, 0.6078],
          ...,
          [0.6902, 0.6902, 0.6902,  ..., 0.3059, 0.3059, 0.3059],
          [0.6902, 0.6902, 0.6902,  ..., 0.3059, 0.3059, 0.3059],
          [0.6902, 0.6902, 0.6902,  ..., 0.3059, 0.3059, 0.3059]],

         [[0.2667, 0.2667, 0.2667,  ..., 0.5176, 0.5176, 0.5176],
          [0.2667, 0.2667, 0.2667,  ..., 0.5176, 0.5176, 0.5176],
          [0.2667, 0.2667, 0.2667,  ..., 0.5176, 0.5176, 0.5176],
          ...,
          [0.6118, 0.6118, 0.6118,  ..., 0.2235, 0.2235, 0.2235],
          [0.6118, 0.6118, 0.6118,  ..., 0.2235, 0.2235, 0.2235],
          [0.6118, 0.6118, 0.6118,  ..., 0.2235, 0.2235, 0.2235]],

         [[0.1686, 0.1686, 0.1686,  ..., 0.3961, 0.3961, 0.3961],
          [0.1686, 0.1686, 0.1686,  ..., 0.3961, 0.3961, 0.3961],
          [0.1686, 0.1686, 0.1686,  ..., 0



tensor([[[[0.5882, 0.5882, 0.5882,  ..., 0.5961, 0.5961, 0.5961],
          [0.5882, 0.5882, 0.5882,  ..., 0.5961, 0.5961, 0.5961],
          [0.5882, 0.5882, 0.5882,  ..., 0.5961, 0.5961, 0.5961],
          ...,
          [0.2745, 0.2745, 0.2745,  ..., 0.4980, 0.4980, 0.4980],
          [0.2745, 0.2745, 0.2745,  ..., 0.4980, 0.4980, 0.4980],
          [0.2745, 0.2745, 0.2745,  ..., 0.4980, 0.4980, 0.4980]],

         [[0.5216, 0.5216, 0.5216,  ..., 0.5020, 0.5020, 0.5020],
          [0.5216, 0.5216, 0.5216,  ..., 0.5020, 0.5020, 0.5020],
          [0.5216, 0.5216, 0.5216,  ..., 0.5020, 0.5020, 0.5020],
          ...,
          [0.2471, 0.2471, 0.2471,  ..., 0.4431, 0.4431, 0.4431],
          [0.2471, 0.2471, 0.2471,  ..., 0.4431, 0.4431, 0.4431],
          [0.2471, 0.2471, 0.2471,  ..., 0.4431, 0.4431, 0.4431]],

         [[0.2118, 0.2118, 0.2118,  ..., 0.2510, 0.2510, 0.2510],
          [0.2118, 0.2118, 0.2118,  ..., 0.2510, 0.2510, 0.2510],
          [0.2118, 0.2118, 0.2118,  ..., 0



tensor([[[[0.7686, 0.7686, 0.7686,  ..., 0.8000, 0.8000, 0.8000],
          [0.7686, 0.7686, 0.7686,  ..., 0.8000, 0.8000, 0.8000],
          [0.7686, 0.7686, 0.7686,  ..., 0.8000, 0.8000, 0.8000],
          ...,
          [0.3686, 0.3686, 0.3686,  ..., 0.4078, 0.4078, 0.4078],
          [0.3686, 0.3686, 0.3686,  ..., 0.4078, 0.4078, 0.4078],
          [0.3686, 0.3686, 0.3686,  ..., 0.4078, 0.4078, 0.4078]],

         [[0.7569, 0.7569, 0.7569,  ..., 0.7922, 0.7922, 0.7922],
          [0.7569, 0.7569, 0.7569,  ..., 0.7922, 0.7922, 0.7922],
          [0.7569, 0.7569, 0.7569,  ..., 0.7922, 0.7922, 0.7922],
          ...,
          [0.3882, 0.3882, 0.3882,  ..., 0.4353, 0.4353, 0.4353],
          [0.3882, 0.3882, 0.3882,  ..., 0.4353, 0.4353, 0.4353],
          [0.3882, 0.3882, 0.3882,  ..., 0.4353, 0.4353, 0.4353]],

         [[0.8275, 0.8275, 0.8275,  ..., 0.8431, 0.8431, 0.8431],
          [0.8275, 0.8275, 0.8275,  ..., 0.8431, 0.8431, 0.8431],
          [0.8275, 0.8275, 0.8275,  ..., 0



tensor([[[[0.1059, 0.1059, 0.1059,  ..., 0.5216, 0.5216, 0.5216],
          [0.1059, 0.1059, 0.1059,  ..., 0.5216, 0.5216, 0.5216],
          [0.1059, 0.1059, 0.1059,  ..., 0.5216, 0.5216, 0.5216],
          ...,
          [0.3255, 0.3255, 0.3255,  ..., 0.5451, 0.5451, 0.5451],
          [0.3255, 0.3255, 0.3255,  ..., 0.5451, 0.5451, 0.5451],
          [0.3255, 0.3255, 0.3255,  ..., 0.5451, 0.5451, 0.5451]],

         [[0.1059, 0.1059, 0.1059,  ..., 0.5451, 0.5451, 0.5451],
          [0.1059, 0.1059, 0.1059,  ..., 0.5451, 0.5451, 0.5451],
          [0.1059, 0.1059, 0.1059,  ..., 0.5451, 0.5451, 0.5451],
          ...,
          [0.3294, 0.3294, 0.3294,  ..., 0.5608, 0.5608, 0.5608],
          [0.3294, 0.3294, 0.3294,  ..., 0.5608, 0.5608, 0.5608],
          [0.3294, 0.3294, 0.3294,  ..., 0.5608, 0.5608, 0.5608]],

         [[0.1059, 0.1059, 0.1059,  ..., 0.5020, 0.5020, 0.5020],
          [0.1059, 0.1059, 0.1059,  ..., 0.5020, 0.5020, 0.5020],
          [0.1059, 0.1059, 0.1059,  ..., 0



tensor([[[[0.4510, 0.4510, 0.4510,  ..., 0.1529, 0.1529, 0.1529],
          [0.4510, 0.4510, 0.4510,  ..., 0.1529, 0.1529, 0.1529],
          [0.4510, 0.4510, 0.4510,  ..., 0.1529, 0.1529, 0.1529],
          ...,
          [0.5961, 0.5961, 0.5961,  ..., 0.2039, 0.2039, 0.2039],
          [0.5961, 0.5961, 0.5961,  ..., 0.2039, 0.2039, 0.2039],
          [0.5961, 0.5961, 0.5961,  ..., 0.2039, 0.2039, 0.2039]],

         [[0.3569, 0.3569, 0.3569,  ..., 0.1451, 0.1451, 0.1451],
          [0.3569, 0.3569, 0.3569,  ..., 0.1451, 0.1451, 0.1451],
          [0.3569, 0.3569, 0.3569,  ..., 0.1451, 0.1451, 0.1451],
          ...,
          [0.7765, 0.7765, 0.7765,  ..., 0.2157, 0.2157, 0.2157],
          [0.7765, 0.7765, 0.7765,  ..., 0.2157, 0.2157, 0.2157],
          [0.7765, 0.7765, 0.7765,  ..., 0.2157, 0.2157, 0.2157]],

         [[0.3294, 0.3294, 0.3294,  ..., 0.1176, 0.1176, 0.1176],
          [0.3294, 0.3294, 0.3294,  ..., 0.1176, 0.1176, 0.1176],
          [0.3294, 0.3294, 0.3294,  ..., 0



tensor([[[[0.9569, 0.9569, 0.9569,  ..., 0.9216, 0.9216, 0.9216],
          [0.9569, 0.9569, 0.9569,  ..., 0.9216, 0.9216, 0.9216],
          [0.9569, 0.9569, 0.9569,  ..., 0.9216, 0.9216, 0.9216],
          ...,
          [0.8078, 0.8078, 0.8078,  ..., 0.9294, 0.9294, 0.9294],
          [0.8078, 0.8078, 0.8078,  ..., 0.9294, 0.9294, 0.9294],
          [0.8078, 0.8078, 0.8078,  ..., 0.9294, 0.9294, 0.9294]],

         [[0.7451, 0.7451, 0.7451,  ..., 0.7098, 0.7098, 0.7098],
          [0.7451, 0.7451, 0.7451,  ..., 0.7098, 0.7098, 0.7098],
          [0.7451, 0.7451, 0.7451,  ..., 0.7098, 0.7098, 0.7098],
          ...,
          [0.6706, 0.6706, 0.6706,  ..., 0.8039, 0.8039, 0.8039],
          [0.6706, 0.6706, 0.6706,  ..., 0.8039, 0.8039, 0.8039],
          [0.6706, 0.6706, 0.6706,  ..., 0.8039, 0.8039, 0.8039]],

         [[0.5569, 0.5569, 0.5569,  ..., 0.5333, 0.5333, 0.5333],
          [0.5569, 0.5569, 0.5569,  ..., 0.5333, 0.5333, 0.5333],
          [0.5569, 0.5569, 0.5569,  ..., 0



tensor([[[[0.2039, 0.2039, 0.2039,  ..., 0.0353, 0.0353, 0.0353],
          [0.2039, 0.2039, 0.2039,  ..., 0.0353, 0.0353, 0.0353],
          [0.2039, 0.2039, 0.2039,  ..., 0.0353, 0.0353, 0.0353],
          ...,
          [0.0745, 0.0745, 0.0745,  ..., 0.0196, 0.0196, 0.0196],
          [0.0745, 0.0745, 0.0745,  ..., 0.0196, 0.0196, 0.0196],
          [0.0745, 0.0745, 0.0745,  ..., 0.0196, 0.0196, 0.0196]],

         [[0.9804, 0.9804, 0.9804,  ..., 0.8549, 0.8549, 0.8549],
          [0.9804, 0.9804, 0.9804,  ..., 0.8549, 0.8549, 0.8549],
          [0.9804, 0.9804, 0.9804,  ..., 0.8549, 0.8549, 0.8549],
          ...,
          [0.9608, 0.9608, 0.9608,  ..., 0.6353, 0.6353, 0.6353],
          [0.9608, 0.9608, 0.9608,  ..., 0.6353, 0.6353, 0.6353],
          [0.9608, 0.9608, 0.9608,  ..., 0.6353, 0.6353, 0.6353]],

         [[0.9176, 0.9176, 0.9176,  ..., 0.7765, 0.7765, 0.7765],
          [0.9176, 0.9176, 0.9176,  ..., 0.7765, 0.7765, 0.7765],
          [0.9176, 0.9176, 0.9176,  ..., 0



tensor([[[[0.1608, 0.1608, 0.1608,  ..., 0.9294, 0.9294, 0.9294],
          [0.1608, 0.1608, 0.1608,  ..., 0.9294, 0.9294, 0.9294],
          [0.1608, 0.1608, 0.1608,  ..., 0.9294, 0.9294, 0.9294],
          ...,
          [0.4314, 0.4314, 0.4314,  ..., 0.2196, 0.2196, 0.2196],
          [0.4314, 0.4314, 0.4314,  ..., 0.2196, 0.2196, 0.2196],
          [0.4314, 0.4314, 0.4314,  ..., 0.2196, 0.2196, 0.2196]],

         [[0.1647, 0.1647, 0.1647,  ..., 0.9255, 0.9255, 0.9255],
          [0.1647, 0.1647, 0.1647,  ..., 0.9255, 0.9255, 0.9255],
          [0.1647, 0.1647, 0.1647,  ..., 0.9255, 0.9255, 0.9255],
          ...,
          [0.4353, 0.4353, 0.4353,  ..., 0.3020, 0.3020, 0.3020],
          [0.4353, 0.4353, 0.4353,  ..., 0.3020, 0.3020, 0.3020],
          [0.4353, 0.4353, 0.4353,  ..., 0.3020, 0.3020, 0.3020]],

         [[0.1451, 0.1451, 0.1451,  ..., 0.9255, 0.9255, 0.9255],
          [0.1451, 0.1451, 0.1451,  ..., 0.9255, 0.9255, 0.9255],
          [0.1451, 0.1451, 0.1451,  ..., 0



tensor([[[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

         [[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

         [[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1



tensor([[[[0.9137, 0.9137, 0.9137,  ..., 0.9294, 0.9294, 0.9294],
          [0.9137, 0.9137, 0.9137,  ..., 0.9294, 0.9294, 0.9294],
          [0.9137, 0.9137, 0.9137,  ..., 0.9294, 0.9294, 0.9294],
          ...,
          [0.9176, 0.9176, 0.9176,  ..., 0.9294, 0.9294, 0.9294],
          [0.9176, 0.9176, 0.9176,  ..., 0.9294, 0.9294, 0.9294],
          [0.9176, 0.9176, 0.9176,  ..., 0.9294, 0.9294, 0.9294]],

         [[0.9843, 0.9843, 0.9843,  ..., 0.9647, 0.9647, 0.9647],
          [0.9843, 0.9843, 0.9843,  ..., 0.9647, 0.9647, 0.9647],
          [0.9843, 0.9843, 0.9843,  ..., 0.9647, 0.9647, 0.9647],
          ...,
          [0.9725, 0.9725, 0.9725,  ..., 0.9922, 0.9922, 0.9922],
          [0.9725, 0.9725, 0.9725,  ..., 0.9922, 0.9922, 0.9922],
          [0.9725, 0.9725, 0.9725,  ..., 0.9922, 0.9922, 0.9922]],

         [[0.8549, 0.8549, 0.8549,  ..., 0.9725, 0.9725, 0.9725],
          [0.8549, 0.8549, 0.8549,  ..., 0.9725, 0.9725, 0.9725],
          [0.8549, 0.8549, 0.8549,  ..., 0



tensor([[[[0.2196, 0.2196, 0.2196,  ..., 0.8471, 0.8471, 0.8471],
          [0.2196, 0.2196, 0.2196,  ..., 0.8471, 0.8471, 0.8471],
          [0.2196, 0.2196, 0.2196,  ..., 0.8471, 0.8471, 0.8471],
          ...,
          [0.6549, 0.6549, 0.6549,  ..., 0.8275, 0.8275, 0.8275],
          [0.6549, 0.6549, 0.6549,  ..., 0.8275, 0.8275, 0.8275],
          [0.6549, 0.6549, 0.6549,  ..., 0.8275, 0.8275, 0.8275]],

         [[0.3020, 0.3020, 0.3020,  ..., 0.8510, 0.8510, 0.8510],
          [0.3020, 0.3020, 0.3020,  ..., 0.8510, 0.8510, 0.8510],
          [0.3020, 0.3020, 0.3020,  ..., 0.8510, 0.8510, 0.8510],
          ...,
          [0.6667, 0.6667, 0.6667,  ..., 0.8275, 0.8275, 0.8275],
          [0.6667, 0.6667, 0.6667,  ..., 0.8275, 0.8275, 0.8275],
          [0.6667, 0.6667, 0.6667,  ..., 0.8275, 0.8275, 0.8275]],

         [[0.3216, 0.3216, 0.3216,  ..., 0.8314, 0.8314, 0.8314],
          [0.3216, 0.3216, 0.3216,  ..., 0.8314, 0.8314, 0.8314],
          [0.3216, 0.3216, 0.3216,  ..., 0



tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0



tensor([[[[0.0078, 0.0078, 0.0078,  ..., 0.0078, 0.0078, 0.0078],
          [0.0078, 0.0078, 0.0078,  ..., 0.0078, 0.0078, 0.0078],
          [0.0078, 0.0078, 0.0078,  ..., 0.0078, 0.0078, 0.0078],
          ...,
          [0.0039, 0.0039, 0.0039,  ..., 0.0510, 0.0510, 0.0510],
          [0.0039, 0.0039, 0.0039,  ..., 0.0510, 0.0510, 0.0510],
          [0.0039, 0.0039, 0.0039,  ..., 0.0510, 0.0510, 0.0510]],

         [[0.0039, 0.0039, 0.0039,  ..., 0.0078, 0.0078, 0.0078],
          [0.0039, 0.0039, 0.0039,  ..., 0.0078, 0.0078, 0.0078],
          [0.0039, 0.0039, 0.0039,  ..., 0.0078, 0.0078, 0.0078],
          ...,
          [0.0039, 0.0039, 0.0039,  ..., 0.0549, 0.0549, 0.0549],
          [0.0039, 0.0039, 0.0039,  ..., 0.0549, 0.0549, 0.0549],
          [0.0039, 0.0039, 0.0039,  ..., 0.0549, 0.0549, 0.0549]],

         [[0.0196, 0.0196, 0.0196,  ..., 0.0235, 0.0235, 0.0235],
          [0.0196, 0.0196, 0.0196,  ..., 0.0235, 0.0235, 0.0235],
          [0.0196, 0.0196, 0.0196,  ..., 0



tensor([[[[0.2784, 0.2784, 0.2784,  ..., 0.4941, 0.4941, 0.4941],
          [0.2784, 0.2784, 0.2784,  ..., 0.4941, 0.4941, 0.4941],
          [0.2784, 0.2784, 0.2784,  ..., 0.4941, 0.4941, 0.4941],
          ...,
          [0.6039, 0.6039, 0.6039,  ..., 0.2706, 0.2706, 0.2706],
          [0.6039, 0.6039, 0.6039,  ..., 0.2706, 0.2706, 0.2706],
          [0.6039, 0.6039, 0.6039,  ..., 0.2706, 0.2706, 0.2706]],

         [[0.3098, 0.3098, 0.3098,  ..., 0.5020, 0.5020, 0.5020],
          [0.3098, 0.3098, 0.3098,  ..., 0.5020, 0.5020, 0.5020],
          [0.3098, 0.3098, 0.3098,  ..., 0.5020, 0.5020, 0.5020],
          ...,
          [0.5765, 0.5765, 0.5765,  ..., 0.2549, 0.2549, 0.2549],
          [0.5765, 0.5765, 0.5765,  ..., 0.2549, 0.2549, 0.2549],
          [0.5765, 0.5765, 0.5765,  ..., 0.2549, 0.2549, 0.2549]],

         [[0.2627, 0.2627, 0.2627,  ..., 0.4157, 0.4157, 0.4157],
          [0.2627, 0.2627, 0.2627,  ..., 0.4157, 0.4157, 0.4157],
          [0.2627, 0.2627, 0.2627,  ..., 0



tensor([[[[0.0941, 0.0941, 0.0941,  ..., 0.5294, 0.5294, 0.5294],
          [0.0941, 0.0941, 0.0941,  ..., 0.5294, 0.5294, 0.5294],
          [0.0941, 0.0941, 0.0941,  ..., 0.5294, 0.5294, 0.5294],
          ...,
          [0.5020, 0.5020, 0.5020,  ..., 0.4706, 0.4706, 0.4706],
          [0.5020, 0.5020, 0.5020,  ..., 0.4706, 0.4706, 0.4706],
          [0.5020, 0.5020, 0.5020,  ..., 0.4706, 0.4706, 0.4706]],

         [[0.1294, 0.1294, 0.1294,  ..., 0.5843, 0.5843, 0.5843],
          [0.1294, 0.1294, 0.1294,  ..., 0.5843, 0.5843, 0.5843],
          [0.1294, 0.1294, 0.1294,  ..., 0.5843, 0.5843, 0.5843],
          ...,
          [0.4863, 0.4863, 0.4863,  ..., 0.4549, 0.4549, 0.4549],
          [0.4863, 0.4863, 0.4863,  ..., 0.4549, 0.4549, 0.4549],
          [0.4863, 0.4863, 0.4863,  ..., 0.4549, 0.4549, 0.4549]],

         [[0.1647, 0.1647, 0.1647,  ..., 0.6667, 0.6667, 0.6667],
          [0.1647, 0.1647, 0.1647,  ..., 0.6667, 0.6667, 0.6667],
          [0.1647, 0.1647, 0.1647,  ..., 0



tensor([[[[0.5412, 0.5412, 0.5412,  ..., 0.4275, 0.4275, 0.4275],
          [0.5412, 0.5412, 0.5412,  ..., 0.4275, 0.4275, 0.4275],
          [0.5412, 0.5412, 0.5412,  ..., 0.4275, 0.4275, 0.4275],
          ...,
          [0.6314, 0.6314, 0.6314,  ..., 0.5725, 0.5725, 0.5725],
          [0.6314, 0.6314, 0.6314,  ..., 0.5725, 0.5725, 0.5725],
          [0.6314, 0.6314, 0.6314,  ..., 0.5725, 0.5725, 0.5725]],

         [[0.4431, 0.4431, 0.4431,  ..., 0.4706, 0.4706, 0.4706],
          [0.4431, 0.4431, 0.4431,  ..., 0.4706, 0.4706, 0.4706],
          [0.4431, 0.4431, 0.4431,  ..., 0.4706, 0.4706, 0.4706],
          ...,
          [0.6784, 0.6784, 0.6784,  ..., 0.6863, 0.6863, 0.6863],
          [0.6784, 0.6784, 0.6784,  ..., 0.6863, 0.6863, 0.6863],
          [0.6784, 0.6784, 0.6784,  ..., 0.6863, 0.6863, 0.6863]],

         [[0.3569, 0.3569, 0.3569,  ..., 0.4510, 0.4510, 0.4510],
          [0.3569, 0.3569, 0.3569,  ..., 0.4510, 0.4510, 0.4510],
          [0.3569, 0.3569, 0.3569,  ..., 0



tensor([[[[0.8510, 0.8510, 0.8510,  ..., 1.0000, 1.0000, 1.0000],
          [0.8510, 0.8510, 0.8510,  ..., 1.0000, 1.0000, 1.0000],
          [0.8510, 0.8510, 0.8510,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [0.4549, 0.4549, 0.4549,  ..., 0.4196, 0.4196, 0.4196],
          [0.4549, 0.4549, 0.4549,  ..., 0.4196, 0.4196, 0.4196],
          [0.4549, 0.4549, 0.4549,  ..., 0.4196, 0.4196, 0.4196]],

         [[0.8471, 0.8471, 0.8471,  ..., 1.0000, 1.0000, 1.0000],
          [0.8471, 0.8471, 0.8471,  ..., 1.0000, 1.0000, 1.0000],
          [0.8471, 0.8471, 0.8471,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [0.4431, 0.4431, 0.4431,  ..., 0.4118, 0.4118, 0.4118],
          [0.4431, 0.4431, 0.4431,  ..., 0.4118, 0.4118, 0.4118],
          [0.4431, 0.4431, 0.4431,  ..., 0.4118, 0.4118, 0.4118]],

         [[0.8627, 0.8627, 0.8627,  ..., 1.0000, 1.0000, 1.0000],
          [0.8627, 0.8627, 0.8627,  ..., 1.0000, 1.0000, 1.0000],
          [0.8627, 0.8627, 0.8627,  ..., 1



tensor([[[[0.4627, 0.4627, 0.4627,  ..., 0.4627, 0.4627, 0.4627],
          [0.4627, 0.4627, 0.4627,  ..., 0.4627, 0.4627, 0.4627],
          [0.4627, 0.4627, 0.4627,  ..., 0.4627, 0.4627, 0.4627],
          ...,
          [0.3725, 0.3725, 0.3725,  ..., 0.3255, 0.3255, 0.3255],
          [0.3725, 0.3725, 0.3725,  ..., 0.3255, 0.3255, 0.3255],
          [0.3725, 0.3725, 0.3725,  ..., 0.3255, 0.3255, 0.3255]],

         [[0.5294, 0.5294, 0.5294,  ..., 0.5412, 0.5412, 0.5412],
          [0.5294, 0.5294, 0.5294,  ..., 0.5412, 0.5412, 0.5412],
          [0.5294, 0.5294, 0.5294,  ..., 0.5412, 0.5412, 0.5412],
          ...,
          [0.4000, 0.4000, 0.4000,  ..., 0.3569, 0.3569, 0.3569],
          [0.4000, 0.4000, 0.4000,  ..., 0.3569, 0.3569, 0.3569],
          [0.4000, 0.4000, 0.4000,  ..., 0.3569, 0.3569, 0.3569]],

         [[0.3176, 0.3176, 0.3176,  ..., 0.3059, 0.3059, 0.3059],
          [0.3176, 0.3176, 0.3176,  ..., 0.3059, 0.3059, 0.3059],
          [0.3176, 0.3176, 0.3176,  ..., 0



tensor([[[[0.4941, 0.4941, 0.4941,  ..., 0.4745, 0.4745, 0.4745],
          [0.4941, 0.4941, 0.4941,  ..., 0.4745, 0.4745, 0.4745],
          [0.4941, 0.4941, 0.4941,  ..., 0.4745, 0.4745, 0.4745],
          ...,
          [0.1373, 0.1373, 0.1373,  ..., 0.0980, 0.0980, 0.0980],
          [0.1373, 0.1373, 0.1373,  ..., 0.0980, 0.0980, 0.0980],
          [0.1373, 0.1373, 0.1373,  ..., 0.0980, 0.0980, 0.0980]],

         [[0.4039, 0.4039, 0.4039,  ..., 0.3725, 0.3725, 0.3725],
          [0.4039, 0.4039, 0.4039,  ..., 0.3725, 0.3725, 0.3725],
          [0.4039, 0.4039, 0.4039,  ..., 0.3725, 0.3725, 0.3725],
          ...,
          [0.1333, 0.1333, 0.1333,  ..., 0.2118, 0.2118, 0.2118],
          [0.1333, 0.1333, 0.1333,  ..., 0.2118, 0.2118, 0.2118],
          [0.1333, 0.1333, 0.1333,  ..., 0.2118, 0.2118, 0.2118]],

         [[0.2824, 0.2824, 0.2824,  ..., 0.2588, 0.2588, 0.2588],
          [0.2824, 0.2824, 0.2824,  ..., 0.2588, 0.2588, 0.2588],
          [0.2824, 0.2824, 0.2824,  ..., 0



tensor([[[[0.8353, 0.8353, 0.8353,  ..., 0.8314, 0.8314, 0.8314],
          [0.8353, 0.8353, 0.8353,  ..., 0.8314, 0.8314, 0.8314],
          [0.8353, 0.8353, 0.8353,  ..., 0.8314, 0.8314, 0.8314],
          ...,
          [0.9059, 0.9059, 0.9059,  ..., 0.8039, 0.8039, 0.8039],
          [0.9059, 0.9059, 0.9059,  ..., 0.8039, 0.8039, 0.8039],
          [0.9059, 0.9059, 0.9059,  ..., 0.8039, 0.8039, 0.8039]],

         [[0.7412, 0.7412, 0.7412,  ..., 0.7961, 0.7961, 0.7961],
          [0.7412, 0.7412, 0.7412,  ..., 0.7961, 0.7961, 0.7961],
          [0.7412, 0.7412, 0.7412,  ..., 0.7961, 0.7961, 0.7961],
          ...,
          [0.7725, 0.7725, 0.7725,  ..., 0.7922, 0.7922, 0.7922],
          [0.7725, 0.7725, 0.7725,  ..., 0.7922, 0.7922, 0.7922],
          [0.7725, 0.7725, 0.7725,  ..., 0.7922, 0.7922, 0.7922]],

         [[0.5686, 0.5686, 0.5686,  ..., 0.6745, 0.6745, 0.6745],
          [0.5686, 0.5686, 0.5686,  ..., 0.6745, 0.6745, 0.6745],
          [0.5686, 0.5686, 0.5686,  ..., 0



tensor([[[[0.0118, 0.0118, 0.0118,  ..., 0.3216, 0.3216, 0.3216],
          [0.0118, 0.0118, 0.0118,  ..., 0.3216, 0.3216, 0.3216],
          [0.0118, 0.0118, 0.0118,  ..., 0.3216, 0.3216, 0.3216],
          ...,
          [0.3216, 0.3216, 0.3216,  ..., 0.4863, 0.4863, 0.4863],
          [0.3216, 0.3216, 0.3216,  ..., 0.4863, 0.4863, 0.4863],
          [0.3216, 0.3216, 0.3216,  ..., 0.4863, 0.4863, 0.4863]],

         [[0.0314, 0.0314, 0.0314,  ..., 0.3961, 0.3961, 0.3961],
          [0.0314, 0.0314, 0.0314,  ..., 0.3961, 0.3961, 0.3961],
          [0.0314, 0.0314, 0.0314,  ..., 0.3961, 0.3961, 0.3961],
          ...,
          [0.3176, 0.3176, 0.3176,  ..., 0.5569, 0.5569, 0.5569],
          [0.3176, 0.3176, 0.3176,  ..., 0.5569, 0.5569, 0.5569],
          [0.3176, 0.3176, 0.3176,  ..., 0.5569, 0.5569, 0.5569]],

         [[0.0118, 0.0118, 0.0118,  ..., 0.3333, 0.3333, 0.3333],
          [0.0118, 0.0118, 0.0118,  ..., 0.3333, 0.3333, 0.3333],
          [0.0118, 0.0118, 0.0118,  ..., 0



tensor([[[[1.0000, 1.0000, 1.0000,  ..., 0.6353, 0.6353, 0.6353],
          [1.0000, 1.0000, 1.0000,  ..., 0.6353, 0.6353, 0.6353],
          [1.0000, 1.0000, 1.0000,  ..., 0.6353, 0.6353, 0.6353],
          ...,
          [0.5059, 0.5059, 0.5059,  ..., 0.5686, 0.5686, 0.5686],
          [0.5059, 0.5059, 0.5059,  ..., 0.5686, 0.5686, 0.5686],
          [0.5059, 0.5059, 0.5059,  ..., 0.5686, 0.5686, 0.5686]],

         [[1.0000, 1.0000, 1.0000,  ..., 0.5725, 0.5725, 0.5725],
          [1.0000, 1.0000, 1.0000,  ..., 0.5725, 0.5725, 0.5725],
          [1.0000, 1.0000, 1.0000,  ..., 0.5725, 0.5725, 0.5725],
          ...,
          [0.4471, 0.4471, 0.4471,  ..., 0.4745, 0.4745, 0.4745],
          [0.4471, 0.4471, 0.4471,  ..., 0.4745, 0.4745, 0.4745],
          [0.4471, 0.4471, 0.4471,  ..., 0.4745, 0.4745, 0.4745]],

         [[1.0000, 1.0000, 1.0000,  ..., 0.5373, 0.5373, 0.5373],
          [1.0000, 1.0000, 1.0000,  ..., 0.5373, 0.5373, 0.5373],
          [1.0000, 1.0000, 1.0000,  ..., 0



tensor([[[[0.7216, 0.7216, 0.7216,  ..., 0.7098, 0.7098, 0.7098],
          [0.7216, 0.7216, 0.7216,  ..., 0.7098, 0.7098, 0.7098],
          [0.7216, 0.7216, 0.7216,  ..., 0.7098, 0.7098, 0.7098],
          ...,
          [0.7569, 0.7569, 0.7569,  ..., 0.7294, 0.7294, 0.7294],
          [0.7569, 0.7569, 0.7569,  ..., 0.7294, 0.7294, 0.7294],
          [0.7569, 0.7569, 0.7569,  ..., 0.7294, 0.7294, 0.7294]],

         [[0.6941, 0.6941, 0.6941,  ..., 0.6784, 0.6784, 0.6784],
          [0.6941, 0.6941, 0.6941,  ..., 0.6784, 0.6784, 0.6784],
          [0.6941, 0.6941, 0.6941,  ..., 0.6784, 0.6784, 0.6784],
          ...,
          [0.7333, 0.7333, 0.7333,  ..., 0.7255, 0.7255, 0.7255],
          [0.7333, 0.7333, 0.7333,  ..., 0.7255, 0.7255, 0.7255],
          [0.7333, 0.7333, 0.7333,  ..., 0.7255, 0.7255, 0.7255]],

         [[0.5804, 0.5804, 0.5804,  ..., 0.5882, 0.5882, 0.5882],
          [0.5804, 0.5804, 0.5804,  ..., 0.5882, 0.5882, 0.5882],
          [0.5804, 0.5804, 0.5804,  ..., 0



tensor([[[[0.5765, 0.5765, 0.5765,  ..., 0.2824, 0.2824, 0.2824],
          [0.5765, 0.5765, 0.5765,  ..., 0.2824, 0.2824, 0.2824],
          [0.5765, 0.5765, 0.5765,  ..., 0.2824, 0.2824, 0.2824],
          ...,
          [0.5608, 0.5608, 0.5608,  ..., 0.4353, 0.4353, 0.4353],
          [0.5608, 0.5608, 0.5608,  ..., 0.4353, 0.4353, 0.4353],
          [0.5608, 0.5608, 0.5608,  ..., 0.4353, 0.4353, 0.4353]],

         [[0.6353, 0.6353, 0.6353,  ..., 0.2863, 0.2863, 0.2863],
          [0.6353, 0.6353, 0.6353,  ..., 0.2863, 0.2863, 0.2863],
          [0.6353, 0.6353, 0.6353,  ..., 0.2863, 0.2863, 0.2863],
          ...,
          [0.6275, 0.6275, 0.6275,  ..., 0.5020, 0.5020, 0.5020],
          [0.6275, 0.6275, 0.6275,  ..., 0.5020, 0.5020, 0.5020],
          [0.6275, 0.6275, 0.6275,  ..., 0.5020, 0.5020, 0.5020]],

         [[0.7176, 0.7176, 0.7176,  ..., 0.3059, 0.3059, 0.3059],
          [0.7176, 0.7176, 0.7176,  ..., 0.3059, 0.3059, 0.3059],
          [0.7176, 0.7176, 0.7176,  ..., 0



tensor([[[[0.3176, 0.3176, 0.3176,  ..., 0.2980, 0.2980, 0.2980],
          [0.3176, 0.3176, 0.3176,  ..., 0.2980, 0.2980, 0.2980],
          [0.3176, 0.3176, 0.3176,  ..., 0.2980, 0.2980, 0.2980],
          ...,
          [0.5294, 0.5294, 0.5294,  ..., 0.5490, 0.5490, 0.5490],
          [0.5294, 0.5294, 0.5294,  ..., 0.5490, 0.5490, 0.5490],
          [0.5294, 0.5294, 0.5294,  ..., 0.5490, 0.5490, 0.5490]],

         [[0.4471, 0.4471, 0.4471,  ..., 0.4118, 0.4118, 0.4118],
          [0.4471, 0.4471, 0.4471,  ..., 0.4118, 0.4118, 0.4118],
          [0.4471, 0.4471, 0.4471,  ..., 0.4118, 0.4118, 0.4118],
          ...,
          [0.4824, 0.4824, 0.4824,  ..., 0.5294, 0.5294, 0.5294],
          [0.4824, 0.4824, 0.4824,  ..., 0.5294, 0.5294, 0.5294],
          [0.4824, 0.4824, 0.4824,  ..., 0.5294, 0.5294, 0.5294]],

         [[0.8039, 0.8039, 0.8039,  ..., 0.7843, 0.7843, 0.7843],
          [0.8039, 0.8039, 0.8039,  ..., 0.7843, 0.7843, 0.7843],
          [0.8039, 0.8039, 0.8039,  ..., 0



tensor([[[[0.4706, 0.4706, 0.4706,  ..., 0.6275, 0.6275, 0.6275],
          [0.4706, 0.4706, 0.4706,  ..., 0.6275, 0.6275, 0.6275],
          [0.4706, 0.4706, 0.4706,  ..., 0.6275, 0.6275, 0.6275],
          ...,
          [0.6863, 0.6863, 0.6863,  ..., 0.2353, 0.2353, 0.2353],
          [0.6863, 0.6863, 0.6863,  ..., 0.2353, 0.2353, 0.2353],
          [0.6863, 0.6863, 0.6863,  ..., 0.2353, 0.2353, 0.2353]],

         [[0.4078, 0.4078, 0.4078,  ..., 0.6196, 0.6196, 0.6196],
          [0.4078, 0.4078, 0.4078,  ..., 0.6196, 0.6196, 0.6196],
          [0.4078, 0.4078, 0.4078,  ..., 0.6196, 0.6196, 0.6196],
          ...,
          [0.5686, 0.5686, 0.5686,  ..., 0.2588, 0.2588, 0.2588],
          [0.5686, 0.5686, 0.5686,  ..., 0.2588, 0.2588, 0.2588],
          [0.5686, 0.5686, 0.5686,  ..., 0.2588, 0.2588, 0.2588]],

         [[0.4588, 0.4588, 0.4588,  ..., 0.6941, 0.6941, 0.6941],
          [0.4588, 0.4588, 0.4588,  ..., 0.6941, 0.6941, 0.6941],
          [0.4588, 0.4588, 0.4588,  ..., 0



tensor([[[[0.6902, 0.6902, 0.6902,  ..., 0.6824, 0.6824, 0.6824],
          [0.6902, 0.6902, 0.6902,  ..., 0.6824, 0.6824, 0.6824],
          [0.6902, 0.6902, 0.6902,  ..., 0.6824, 0.6824, 0.6824],
          ...,
          [0.7294, 0.7294, 0.7294,  ..., 0.6588, 0.6588, 0.6588],
          [0.7294, 0.7294, 0.7294,  ..., 0.6588, 0.6588, 0.6588],
          [0.7294, 0.7294, 0.7294,  ..., 0.6588, 0.6588, 0.6588]],

         [[0.6902, 0.6902, 0.6902,  ..., 0.6824, 0.6824, 0.6824],
          [0.6902, 0.6902, 0.6902,  ..., 0.6824, 0.6824, 0.6824],
          [0.6902, 0.6902, 0.6902,  ..., 0.6824, 0.6824, 0.6824],
          ...,
          [0.7294, 0.7294, 0.7294,  ..., 0.6588, 0.6588, 0.6588],
          [0.7294, 0.7294, 0.7294,  ..., 0.6588, 0.6588, 0.6588],
          [0.7294, 0.7294, 0.7294,  ..., 0.6588, 0.6588, 0.6588]],

         [[0.6902, 0.6902, 0.6902,  ..., 0.6824, 0.6824, 0.6824],
          [0.6902, 0.6902, 0.6902,  ..., 0.6824, 0.6824, 0.6824],
          [0.6902, 0.6902, 0.6902,  ..., 0



tensor([[[[0.4157, 0.4157, 0.4157,  ..., 0.4118, 0.4118, 0.4118],
          [0.4157, 0.4157, 0.4157,  ..., 0.4118, 0.4118, 0.4118],
          [0.4157, 0.4157, 0.4157,  ..., 0.4118, 0.4118, 0.4118],
          ...,
          [0.5176, 0.5176, 0.5176,  ..., 0.4980, 0.4980, 0.4980],
          [0.5176, 0.5176, 0.5176,  ..., 0.4980, 0.4980, 0.4980],
          [0.5176, 0.5176, 0.5176,  ..., 0.4980, 0.4980, 0.4980]],

         [[0.5020, 0.5020, 0.5020,  ..., 0.4863, 0.4863, 0.4863],
          [0.5020, 0.5020, 0.5020,  ..., 0.4863, 0.4863, 0.4863],
          [0.5020, 0.5020, 0.5020,  ..., 0.4863, 0.4863, 0.4863],
          ...,
          [0.5961, 0.5961, 0.5961,  ..., 0.5725, 0.5725, 0.5725],
          [0.5961, 0.5961, 0.5961,  ..., 0.5725, 0.5725, 0.5725],
          [0.5961, 0.5961, 0.5961,  ..., 0.5725, 0.5725, 0.5725]],

         [[0.8118, 0.8118, 0.8118,  ..., 0.7843, 0.7843, 0.7843],
          [0.8118, 0.8118, 0.8118,  ..., 0.7843, 0.7843, 0.7843],
          [0.8118, 0.8118, 0.8118,  ..., 0



tensor([[[[0.5333, 0.5333, 0.5333,  ..., 0.4275, 0.4275, 0.4275],
          [0.5333, 0.5333, 0.5333,  ..., 0.4275, 0.4275, 0.4275],
          [0.5333, 0.5333, 0.5333,  ..., 0.4275, 0.4275, 0.4275],
          ...,
          [0.5882, 0.5882, 0.5882,  ..., 0.3882, 0.3882, 0.3882],
          [0.5882, 0.5882, 0.5882,  ..., 0.3882, 0.3882, 0.3882],
          [0.5882, 0.5882, 0.5882,  ..., 0.3882, 0.3882, 0.3882]],

         [[0.5294, 0.5294, 0.5294,  ..., 0.3961, 0.3961, 0.3961],
          [0.5294, 0.5294, 0.5294,  ..., 0.3961, 0.3961, 0.3961],
          [0.5294, 0.5294, 0.5294,  ..., 0.3961, 0.3961, 0.3961],
          ...,
          [0.5373, 0.5373, 0.5373,  ..., 0.3529, 0.3529, 0.3529],
          [0.5373, 0.5373, 0.5373,  ..., 0.3529, 0.3529, 0.3529],
          [0.5373, 0.5373, 0.5373,  ..., 0.3529, 0.3529, 0.3529]],

         [[0.4078, 0.4078, 0.4078,  ..., 0.3804, 0.3804, 0.3804],
          [0.4078, 0.4078, 0.4078,  ..., 0.3804, 0.3804, 0.3804],
          [0.4078, 0.4078, 0.4078,  ..., 0



tensor([[[[0.4157, 0.4157, 0.4157,  ..., 0.4471, 0.4471, 0.4471],
          [0.4157, 0.4157, 0.4157,  ..., 0.4471, 0.4471, 0.4471],
          [0.4157, 0.4157, 0.4157,  ..., 0.4471, 0.4471, 0.4471],
          ...,
          [0.7686, 0.7686, 0.7686,  ..., 0.7765, 0.7765, 0.7765],
          [0.7686, 0.7686, 0.7686,  ..., 0.7765, 0.7765, 0.7765],
          [0.7686, 0.7686, 0.7686,  ..., 0.7765, 0.7765, 0.7765]],

         [[0.4275, 0.4275, 0.4275,  ..., 0.4235, 0.4235, 0.4235],
          [0.4275, 0.4275, 0.4275,  ..., 0.4235, 0.4235, 0.4235],
          [0.4275, 0.4275, 0.4275,  ..., 0.4235, 0.4235, 0.4235],
          ...,
          [0.6667, 0.6667, 0.6667,  ..., 0.6863, 0.6863, 0.6863],
          [0.6667, 0.6667, 0.6667,  ..., 0.6863, 0.6863, 0.6863],
          [0.6667, 0.6667, 0.6667,  ..., 0.6863, 0.6863, 0.6863]],

         [[0.3137, 0.3137, 0.3137,  ..., 0.4235, 0.4235, 0.4235],
          [0.3137, 0.3137, 0.3137,  ..., 0.4235, 0.4235, 0.4235],
          [0.3137, 0.3137, 0.3137,  ..., 0



tensor([[[[0.4706, 0.4706, 0.4706,  ..., 0.4392, 0.4392, 0.4392],
          [0.4706, 0.4706, 0.4706,  ..., 0.4392, 0.4392, 0.4392],
          [0.4706, 0.4706, 0.4706,  ..., 0.4392, 0.4392, 0.4392],
          ...,
          [0.4196, 0.4196, 0.4196,  ..., 0.5804, 0.5804, 0.5804],
          [0.4196, 0.4196, 0.4196,  ..., 0.5804, 0.5804, 0.5804],
          [0.4196, 0.4196, 0.4196,  ..., 0.5804, 0.5804, 0.5804]],

         [[0.4353, 0.4353, 0.4353,  ..., 0.4118, 0.4118, 0.4118],
          [0.4353, 0.4353, 0.4353,  ..., 0.4118, 0.4118, 0.4118],
          [0.4353, 0.4353, 0.4353,  ..., 0.4118, 0.4118, 0.4118],
          ...,
          [0.3961, 0.3961, 0.3961,  ..., 0.5412, 0.5412, 0.5412],
          [0.3961, 0.3961, 0.3961,  ..., 0.5412, 0.5412, 0.5412],
          [0.3961, 0.3961, 0.3961,  ..., 0.5412, 0.5412, 0.5412]],

         [[0.4118, 0.4118, 0.4118,  ..., 0.4392, 0.4392, 0.4392],
          [0.4118, 0.4118, 0.4118,  ..., 0.4392, 0.4392, 0.4392],
          [0.4118, 0.4118, 0.4118,  ..., 0



tensor([[[[1.0000, 1.0000, 1.0000,  ..., 0.6588, 0.6588, 0.6588],
          [1.0000, 1.0000, 1.0000,  ..., 0.6588, 0.6588, 0.6588],
          [1.0000, 1.0000, 1.0000,  ..., 0.6588, 0.6588, 0.6588],
          ...,
          [0.3922, 0.3922, 0.3922,  ..., 0.9922, 0.9922, 0.9922],
          [0.3922, 0.3922, 0.3922,  ..., 0.9922, 0.9922, 0.9922],
          [0.3922, 0.3922, 0.3922,  ..., 0.9922, 0.9922, 0.9922]],

         [[1.0000, 1.0000, 1.0000,  ..., 0.2784, 0.2784, 0.2784],
          [1.0000, 1.0000, 1.0000,  ..., 0.2784, 0.2784, 0.2784],
          [1.0000, 1.0000, 1.0000,  ..., 0.2784, 0.2784, 0.2784],
          ...,
          [0.3804, 0.3804, 0.3804,  ..., 1.0000, 1.0000, 1.0000],
          [0.3804, 0.3804, 0.3804,  ..., 1.0000, 1.0000, 1.0000],
          [0.3804, 0.3804, 0.3804,  ..., 1.0000, 1.0000, 1.0000]],

         [[1.0000, 1.0000, 1.0000,  ..., 0.3098, 0.3098, 0.3098],
          [1.0000, 1.0000, 1.0000,  ..., 0.3098, 0.3098, 0.3098],
          [1.0000, 1.0000, 1.0000,  ..., 0



tensor([[[[0.0588, 0.0588, 0.0588,  ..., 0.0941, 0.0941, 0.0941],
          [0.0588, 0.0588, 0.0588,  ..., 0.0941, 0.0941, 0.0941],
          [0.0588, 0.0588, 0.0588,  ..., 0.0941, 0.0941, 0.0941],
          ...,
          [0.0667, 0.0667, 0.0667,  ..., 0.2000, 0.2000, 0.2000],
          [0.0667, 0.0667, 0.0667,  ..., 0.2000, 0.2000, 0.2000],
          [0.0667, 0.0667, 0.0667,  ..., 0.2000, 0.2000, 0.2000]],

         [[0.0588, 0.0588, 0.0588,  ..., 0.0902, 0.0902, 0.0902],
          [0.0588, 0.0588, 0.0588,  ..., 0.0902, 0.0902, 0.0902],
          [0.0588, 0.0588, 0.0588,  ..., 0.0902, 0.0902, 0.0902],
          ...,
          [0.0745, 0.0745, 0.0745,  ..., 0.3647, 0.3647, 0.3647],
          [0.0745, 0.0745, 0.0745,  ..., 0.3647, 0.3647, 0.3647],
          [0.0745, 0.0745, 0.0745,  ..., 0.3647, 0.3647, 0.3647]],

         [[0.0588, 0.0588, 0.0588,  ..., 0.0824, 0.0824, 0.0824],
          [0.0588, 0.0588, 0.0588,  ..., 0.0824, 0.0824, 0.0824],
          [0.0588, 0.0588, 0.0588,  ..., 0



tensor([[[[0.2980, 0.2980, 0.2980,  ..., 0.2902, 0.2902, 0.2902],
          [0.2980, 0.2980, 0.2980,  ..., 0.2902, 0.2902, 0.2902],
          [0.2980, 0.2980, 0.2980,  ..., 0.2902, 0.2902, 0.2902],
          ...,
          [0.3569, 0.3569, 0.3569,  ..., 0.3961, 0.3961, 0.3961],
          [0.3569, 0.3569, 0.3569,  ..., 0.3961, 0.3961, 0.3961],
          [0.3569, 0.3569, 0.3569,  ..., 0.3961, 0.3961, 0.3961]],

         [[0.2784, 0.2784, 0.2784,  ..., 0.2549, 0.2549, 0.2549],
          [0.2784, 0.2784, 0.2784,  ..., 0.2549, 0.2549, 0.2549],
          [0.2784, 0.2784, 0.2784,  ..., 0.2549, 0.2549, 0.2549],
          ...,
          [0.3216, 0.3216, 0.3216,  ..., 0.3686, 0.3686, 0.3686],
          [0.3216, 0.3216, 0.3216,  ..., 0.3686, 0.3686, 0.3686],
          [0.3216, 0.3216, 0.3216,  ..., 0.3686, 0.3686, 0.3686]],

         [[0.2549, 0.2549, 0.2549,  ..., 0.2510, 0.2510, 0.2510],
          [0.2549, 0.2549, 0.2549,  ..., 0.2510, 0.2510, 0.2510],
          [0.2549, 0.2549, 0.2549,  ..., 0



tensor([[[[0.5255, 0.5255, 0.5255,  ..., 0.4902, 0.4902, 0.4902],
          [0.5255, 0.5255, 0.5255,  ..., 0.4902, 0.4902, 0.4902],
          [0.5255, 0.5255, 0.5255,  ..., 0.4902, 0.4902, 0.4902],
          ...,
          [0.5255, 0.5255, 0.5255,  ..., 0.4431, 0.4431, 0.4431],
          [0.5255, 0.5255, 0.5255,  ..., 0.4431, 0.4431, 0.4431],
          [0.5255, 0.5255, 0.5255,  ..., 0.4431, 0.4431, 0.4431]],

         [[0.4235, 0.4235, 0.4235,  ..., 0.3804, 0.3804, 0.3804],
          [0.4235, 0.4235, 0.4235,  ..., 0.3804, 0.3804, 0.3804],
          [0.4235, 0.4235, 0.4235,  ..., 0.3804, 0.3804, 0.3804],
          ...,
          [0.4431, 0.4431, 0.4431,  ..., 0.3608, 0.3608, 0.3608],
          [0.4431, 0.4431, 0.4431,  ..., 0.3608, 0.3608, 0.3608],
          [0.4431, 0.4431, 0.4431,  ..., 0.3608, 0.3608, 0.3608]],

         [[0.3294, 0.3294, 0.3294,  ..., 0.3020, 0.3020, 0.3020],
          [0.3294, 0.3294, 0.3294,  ..., 0.3020, 0.3020, 0.3020],
          [0.3294, 0.3294, 0.3294,  ..., 0



tensor([[[[0.9020, 0.9020, 0.9020,  ..., 0.6902, 0.6902, 0.6902],
          [0.9020, 0.9020, 0.9020,  ..., 0.6902, 0.6902, 0.6902],
          [0.9020, 0.9020, 0.9020,  ..., 0.6902, 0.6902, 0.6902],
          ...,
          [0.1333, 0.1333, 0.1333,  ..., 0.1686, 0.1686, 0.1686],
          [0.1333, 0.1333, 0.1333,  ..., 0.1686, 0.1686, 0.1686],
          [0.1333, 0.1333, 0.1333,  ..., 0.1686, 0.1686, 0.1686]],

         [[0.9569, 0.9569, 0.9569,  ..., 0.6863, 0.6863, 0.6863],
          [0.9569, 0.9569, 0.9569,  ..., 0.6863, 0.6863, 0.6863],
          [0.9569, 0.9569, 0.9569,  ..., 0.6863, 0.6863, 0.6863],
          ...,
          [0.2157, 0.2157, 0.2157,  ..., 0.2627, 0.2627, 0.2627],
          [0.2157, 0.2157, 0.2157,  ..., 0.2627, 0.2627, 0.2627],
          [0.2157, 0.2157, 0.2157,  ..., 0.2627, 0.2627, 0.2627]],

         [[0.9686, 0.9686, 0.9686,  ..., 0.7176, 0.7176, 0.7176],
          [0.9686, 0.9686, 0.9686,  ..., 0.7176, 0.7176, 0.7176],
          [0.9686, 0.9686, 0.9686,  ..., 0



tensor([[[[0.7255, 0.7255, 0.7255,  ..., 0.7216, 0.7216, 0.7216],
          [0.7255, 0.7255, 0.7255,  ..., 0.7216, 0.7216, 0.7216],
          [0.7255, 0.7255, 0.7255,  ..., 0.7216, 0.7216, 0.7216],
          ...,
          [0.2275, 0.2275, 0.2275,  ..., 0.1922, 0.1922, 0.1922],
          [0.2275, 0.2275, 0.2275,  ..., 0.1922, 0.1922, 0.1922],
          [0.2275, 0.2275, 0.2275,  ..., 0.1922, 0.1922, 0.1922]],

         [[0.7725, 0.7725, 0.7725,  ..., 0.7843, 0.7843, 0.7843],
          [0.7725, 0.7725, 0.7725,  ..., 0.7843, 0.7843, 0.7843],
          [0.7725, 0.7725, 0.7725,  ..., 0.7843, 0.7843, 0.7843],
          ...,
          [0.4627, 0.4627, 0.4627,  ..., 0.4235, 0.4235, 0.4235],
          [0.4627, 0.4627, 0.4627,  ..., 0.4235, 0.4235, 0.4235],
          [0.4627, 0.4627, 0.4627,  ..., 0.4235, 0.4235, 0.4235]],

         [[0.8078, 0.8078, 0.8078,  ..., 0.8392, 0.8392, 0.8392],
          [0.8078, 0.8078, 0.8078,  ..., 0.8392, 0.8392, 0.8392],
          [0.8078, 0.8078, 0.8078,  ..., 0



tensor([[[[0.4431, 0.4431, 0.4431,  ..., 0.6745, 0.6745, 0.6745],
          [0.4431, 0.4431, 0.4431,  ..., 0.6745, 0.6745, 0.6745],
          [0.4431, 0.4431, 0.4431,  ..., 0.6745, 0.6745, 0.6745],
          ...,
          [0.1882, 0.1882, 0.1882,  ..., 0.6863, 0.6863, 0.6863],
          [0.1882, 0.1882, 0.1882,  ..., 0.6863, 0.6863, 0.6863],
          [0.1882, 0.1882, 0.1882,  ..., 0.6863, 0.6863, 0.6863]],

         [[0.3843, 0.3843, 0.3843,  ..., 0.6745, 0.6745, 0.6745],
          [0.3843, 0.3843, 0.3843,  ..., 0.6745, 0.6745, 0.6745],
          [0.3843, 0.3843, 0.3843,  ..., 0.6745, 0.6745, 0.6745],
          ...,
          [0.2039, 0.2039, 0.2039,  ..., 0.6667, 0.6667, 0.6667],
          [0.2039, 0.2039, 0.2039,  ..., 0.6667, 0.6667, 0.6667],
          [0.2039, 0.2039, 0.2039,  ..., 0.6667, 0.6667, 0.6667]],

         [[0.2706, 0.2706, 0.2706,  ..., 0.6824, 0.6824, 0.6824],
          [0.2706, 0.2706, 0.2706,  ..., 0.6824, 0.6824, 0.6824],
          [0.2706, 0.2706, 0.2706,  ..., 0



tensor([[[[0.9686, 0.9686, 0.9686,  ..., 0.9608, 0.9608, 0.9608],
          [0.9686, 0.9686, 0.9686,  ..., 0.9608, 0.9608, 0.9608],
          [0.9686, 0.9686, 0.9686,  ..., 0.9608, 0.9608, 0.9608],
          ...,
          [0.8824, 0.8824, 0.8824,  ..., 0.6431, 0.6431, 0.6431],
          [0.8824, 0.8824, 0.8824,  ..., 0.6431, 0.6431, 0.6431],
          [0.8824, 0.8824, 0.8824,  ..., 0.6431, 0.6431, 0.6431]],

         [[0.9647, 0.9647, 0.9647,  ..., 0.9608, 0.9608, 0.9608],
          [0.9647, 0.9647, 0.9647,  ..., 0.9608, 0.9608, 0.9608],
          [0.9647, 0.9647, 0.9647,  ..., 0.9608, 0.9608, 0.9608],
          ...,
          [0.8824, 0.8824, 0.8824,  ..., 0.6431, 0.6431, 0.6431],
          [0.8824, 0.8824, 0.8824,  ..., 0.6431, 0.6431, 0.6431],
          [0.8824, 0.8824, 0.8824,  ..., 0.6431, 0.6431, 0.6431]],

         [[0.9451, 0.9451, 0.9451,  ..., 0.9412, 0.9412, 0.9412],
          [0.9451, 0.9451, 0.9451,  ..., 0.9412, 0.9412, 0.9412],
          [0.9451, 0.9451, 0.9451,  ..., 0



tensor([[[[1.0000, 1.0000, 1.0000,  ..., 0.9882, 0.9882, 0.9882],
          [1.0000, 1.0000, 1.0000,  ..., 0.9882, 0.9882, 0.9882],
          [1.0000, 1.0000, 1.0000,  ..., 0.9882, 0.9882, 0.9882],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 0.9216, 0.9216, 0.9216],
          [1.0000, 1.0000, 1.0000,  ..., 0.9216, 0.9216, 0.9216],
          [1.0000, 1.0000, 1.0000,  ..., 0.9216, 0.9216, 0.9216]],

         [[1.0000, 1.0000, 1.0000,  ..., 0.9882, 0.9882, 0.9882],
          [1.0000, 1.0000, 1.0000,  ..., 0.9882, 0.9882, 0.9882],
          [1.0000, 1.0000, 1.0000,  ..., 0.9882, 0.9882, 0.9882],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 0.9216, 0.9216, 0.9216],
          [1.0000, 1.0000, 1.0000,  ..., 0.9216, 0.9216, 0.9216],
          [1.0000, 1.0000, 1.0000,  ..., 0.9216, 0.9216, 0.9216]],

         [[1.0000, 1.0000, 1.0000,  ..., 0.9882, 0.9882, 0.9882],
          [1.0000, 1.0000, 1.0000,  ..., 0.9882, 0.9882, 0.9882],
          [1.0000, 1.0000, 1.0000,  ..., 0



tensor([[[[0.7373, 0.7373, 0.7373,  ..., 0.6784, 0.6784, 0.6784],
          [0.7373, 0.7373, 0.7373,  ..., 0.6784, 0.6784, 0.6784],
          [0.7373, 0.7373, 0.7373,  ..., 0.6784, 0.6784, 0.6784],
          ...,
          [0.6353, 0.6353, 0.6353,  ..., 0.6196, 0.6196, 0.6196],
          [0.6353, 0.6353, 0.6353,  ..., 0.6196, 0.6196, 0.6196],
          [0.6353, 0.6353, 0.6353,  ..., 0.6196, 0.6196, 0.6196]],

         [[0.7765, 0.7765, 0.7765,  ..., 0.7843, 0.7843, 0.7843],
          [0.7765, 0.7765, 0.7765,  ..., 0.7843, 0.7843, 0.7843],
          [0.7765, 0.7765, 0.7765,  ..., 0.7843, 0.7843, 0.7843],
          ...,
          [0.5765, 0.5765, 0.5765,  ..., 0.5373, 0.5373, 0.5373],
          [0.5765, 0.5765, 0.5765,  ..., 0.5373, 0.5373, 0.5373],
          [0.5765, 0.5765, 0.5765,  ..., 0.5373, 0.5373, 0.5373]],

         [[0.8431, 0.8431, 0.8431,  ..., 0.8745, 0.8745, 0.8745],
          [0.8431, 0.8431, 0.8431,  ..., 0.8745, 0.8745, 0.8745],
          [0.8431, 0.8431, 0.8431,  ..., 0



tensor([[[[0.6000, 0.6000, 0.6000,  ..., 0.8745, 0.8745, 0.8745],
          [0.6000, 0.6000, 0.6000,  ..., 0.8745, 0.8745, 0.8745],
          [0.6000, 0.6000, 0.6000,  ..., 0.8745, 0.8745, 0.8745],
          ...,
          [0.9647, 0.9647, 0.9647,  ..., 0.6824, 0.6824, 0.6824],
          [0.9647, 0.9647, 0.9647,  ..., 0.6824, 0.6824, 0.6824],
          [0.9647, 0.9647, 0.9647,  ..., 0.6824, 0.6824, 0.6824]],

         [[0.5725, 0.5725, 0.5725,  ..., 0.8510, 0.8510, 0.8510],
          [0.5725, 0.5725, 0.5725,  ..., 0.8510, 0.8510, 0.8510],
          [0.5725, 0.5725, 0.5725,  ..., 0.8510, 0.8510, 0.8510],
          ...,
          [0.9686, 0.9686, 0.9686,  ..., 0.6471, 0.6471, 0.6471],
          [0.9686, 0.9686, 0.9686,  ..., 0.6471, 0.6471, 0.6471],
          [0.9686, 0.9686, 0.9686,  ..., 0.6471, 0.6471, 0.6471]],

         [[0.4745, 0.4745, 0.4745,  ..., 0.8157, 0.8157, 0.8157],
          [0.4745, 0.4745, 0.4745,  ..., 0.8157, 0.8157, 0.8157],
          [0.4745, 0.4745, 0.4745,  ..., 0



tensor([[[[0.6000, 0.6000, 0.6000,  ..., 0.5961, 0.5961, 0.5961],
          [0.6000, 0.6000, 0.6000,  ..., 0.5961, 0.5961, 0.5961],
          [0.6000, 0.6000, 0.6000,  ..., 0.5961, 0.5961, 0.5961],
          ...,
          [0.7333, 0.7333, 0.7333,  ..., 0.7373, 0.7373, 0.7373],
          [0.7333, 0.7333, 0.7333,  ..., 0.7373, 0.7373, 0.7373],
          [0.7333, 0.7333, 0.7333,  ..., 0.7373, 0.7373, 0.7373]],

         [[0.5804, 0.5804, 0.5804,  ..., 0.5725, 0.5725, 0.5725],
          [0.5804, 0.5804, 0.5804,  ..., 0.5725, 0.5725, 0.5725],
          [0.5804, 0.5804, 0.5804,  ..., 0.5725, 0.5725, 0.5725],
          ...,
          [0.7020, 0.7020, 0.7020,  ..., 0.7294, 0.7294, 0.7294],
          [0.7020, 0.7020, 0.7020,  ..., 0.7294, 0.7294, 0.7294],
          [0.7020, 0.7020, 0.7020,  ..., 0.7294, 0.7294, 0.7294]],

         [[0.4745, 0.4745, 0.4745,  ..., 0.4471, 0.4471, 0.4471],
          [0.4745, 0.4745, 0.4745,  ..., 0.4471, 0.4471, 0.4471],
          [0.4745, 0.4745, 0.4745,  ..., 0



tensor([[[[0.1137, 0.1137, 0.1137,  ..., 0.1294, 0.1294, 0.1294],
          [0.1137, 0.1137, 0.1137,  ..., 0.1294, 0.1294, 0.1294],
          [0.1137, 0.1137, 0.1137,  ..., 0.1294, 0.1294, 0.1294],
          ...,
          [0.6431, 0.6431, 0.6431,  ..., 0.5373, 0.5373, 0.5373],
          [0.6431, 0.6431, 0.6431,  ..., 0.5373, 0.5373, 0.5373],
          [0.6431, 0.6431, 0.6431,  ..., 0.5373, 0.5373, 0.5373]],

         [[0.1137, 0.1137, 0.1137,  ..., 0.1333, 0.1333, 0.1333],
          [0.1137, 0.1137, 0.1137,  ..., 0.1333, 0.1333, 0.1333],
          [0.1137, 0.1137, 0.1137,  ..., 0.1333, 0.1333, 0.1333],
          ...,
          [0.7020, 0.7020, 0.7020,  ..., 0.6235, 0.6235, 0.6235],
          [0.7020, 0.7020, 0.7020,  ..., 0.6235, 0.6235, 0.6235],
          [0.7020, 0.7020, 0.7020,  ..., 0.6235, 0.6235, 0.6235]],

         [[0.1137, 0.1137, 0.1137,  ..., 0.1176, 0.1176, 0.1176],
          [0.1137, 0.1137, 0.1137,  ..., 0.1176, 0.1176, 0.1176],
          [0.1137, 0.1137, 0.1137,  ..., 0



[1;30;43mStreaming output truncated to the last 5000 lines.[0m

         [[0.0784, 0.0784, 0.0784,  ..., 0.0941, 0.0941, 0.0941],
          [0.0784, 0.0784, 0.0784,  ..., 0.0941, 0.0941, 0.0941],
          [0.0784, 0.0784, 0.0784,  ..., 0.0941, 0.0941, 0.0941],
          ...,
          [0.5922, 0.5922, 0.5922,  ..., 0.7294, 0.7294, 0.7294],
          [0.5922, 0.5922, 0.5922,  ..., 0.7294, 0.7294, 0.7294],
          [0.5922, 0.5922, 0.5922,  ..., 0.7294, 0.7294, 0.7294]],

         [[0.0549, 0.0549, 0.0549,  ..., 0.0667, 0.0667, 0.0667],
          [0.0549, 0.0549, 0.0549,  ..., 0.0667, 0.0667, 0.0667],
          [0.0549, 0.0549, 0.0549,  ..., 0.0667, 0.0667, 0.0667],
          ...,
          [0.5922, 0.5922, 0.5922,  ..., 0.7294, 0.7294, 0.7294],
          [0.5922, 0.5922, 0.5922,  ..., 0.7294, 0.7294, 0.7294],
          [0.5922, 0.5922, 0.5922,  ..., 0.7294, 0.7294, 0.7294]]],


        [[[0.1529, 0.1529, 0.1529,  ..., 0.1490, 0.1490, 0.1490],
          [0.1529, 0.1529, 0.1529,  ...,

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

In [None]:
plt.imshow(test_dataset[50][0].T)
plt.show()

In [None]:
train_dataset.classes_to_idx

In [None]:
x_output = vae(test_dataset[50][0].unsqueeze(0).to(device))

In [None]:
decoded_image = x_output[3].cpu().detach()[0]

In [None]:
decoded_image = decoded_image.type(torch.uint8)

In [None]:
decoded_image

In [None]:
imgplot = plt.imshow(decoded_image.T)
plt.show()

In [None]:
wandb.finish()

---

In [377]:
!git clone https://github.com/beamaia/HistoVAE/

Cloning into 'HistoVAE'...
remote: Enumerating objects: 379, done.[K
remote: Total 379 (delta 0), reused 0 (delta 0), pack-reused 379[K
Receiving objects: 100% (379/379), 4.64 MiB | 14.57 MiB/s, done.
Resolving deltas: 100% (224/224), done.


In [378]:
!pwd

/content


In [386]:
!ls histo

dependencies  figures	README.md	  scripts  txt	  VAE
diagrams      Makefile	requirements.txt  src	   utils


In [387]:
!cd histo

In [388]:
!ls

data  drive  histo  sample_data  wandb


In [389]:
!python histo/scripts/train.py --model_type VariationalConvolutionalAutoencoder

Traceback (most recent call last):
  File "/content/histo/scripts/train.py", line 9, in <module>
    from src.classes import Dataset
ModuleNotFoundError: No module named 'src'


In [400]:
'''This script demonstrates how to build a variational autoencoder
with Keras and deconvolution layers.
# Reference
- Auto-Encoding Variational Bayes
  https://arxiv.org/abs/1312.6114
'''
import numpy as np
from scipy.stats import norm

from keras.layers import Input, Dense, Lambda, Flatten, Reshape, Layer
from keras.layers import Conv2D, Conv2DTranspose
from keras.models import Model
from keras import backend as K
from keras import metrics
from keras.datasets import mnist, cifar10
import sys
rootdir = '/content/'
sys.path.append(rootdir)
from skimage.transform import resize
from sklearn.model_selection import train_test_split
from keras import optimizers
from keras.callbacks import TensorBoard, EarlyStopping, ModelCheckpoint



np.random.seed(42)

# number of convolutional filters to use
filters = 64
# convolution kernel size
num_conv = 3
batch_size = 32


dataset = 'mnist'

if dataset == 'mnist':
    img_rows, img_cols, img_chns = 28, 28, 1
elif dataset == 'cifar10':
    img_rows, img_cols, img_chns = 32, 32, 3
elif dataset == 'histo-dev':
    img_rows, img_cols, img_chns = 256, 256, 3
elif dataset == '128-patches':
    img_rows, img_cols, img_chns = 128, 128, 3

if K.image_data_format() == 'channels_first':
    original_img_size = (img_chns, img_rows, img_cols)
else:
    original_img_size = (img_rows, img_cols, img_chns)




latent_dim = 16
intermediate_dim = 512
epsilon_std = 1.0
epochs = 100
lr = 0.0001

x = Input(shape=original_img_size)
conv_1 = Conv2D(img_chns,
                kernel_size=(2, 2),
                padding='same', activation='relu')(x)
conv_2 = Conv2D(filters,
                kernel_size=(2, 2),
                padding='same', activation='relu',
                strides=(2, 2))(conv_1)
conv_3 = Conv2D(filters,
                kernel_size=num_conv,
                padding='same', activation='relu',
                strides=1)(conv_2)
conv_4 = Conv2D(filters,
                kernel_size=num_conv,
                padding='same', activation='relu',
                strides=1)(conv_3)
flat = Flatten()(conv_4)
hidden = Dense(intermediate_dim, activation='relu')(flat)

z_mean = Dense(latent_dim)(hidden)
z_log_var = Dense(latent_dim)(hidden)


def sampling(args):
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim),
                              mean=0., stddev=epsilon_std)
    return z_mean + K.exp(z_log_var) * epsilon


# note that "output_shape" isn't necessary with the TensorFlow backend
# so you could write `Lambda(sampling)([z_mean, z_log_var])`
print(sampling)
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])

# we instantiate these layers separately so as to reuse them later


decoder_hid = Dense(intermediate_dim, activation='relu')
decoder_upsample = Dense(filters * int(img_rows/2) * int(img_rows/2), activation='relu')

if K.image_data_format() == 'channels_first':
    output_shape = (batch_size, filters, int(img_rows/2), int(img_rows/2))
else:
    output_shape = (batch_size, int(img_rows/2), int(img_rows/2), filters)

decoder_reshape = Reshape(output_shape[1:])
decoder_deconv_1 = Conv2DTranspose(filters,
                                   kernel_size=num_conv,
                                   padding='same',
                                   strides=1,
                                   activation='relu')
decoder_deconv_2 = Conv2DTranspose(filters,
                                   kernel_size=num_conv,
                                   padding='same',
                                   strides=1,
                                   activation='relu')
if K.image_data_format() == 'channels_first':
    output_shape = (batch_size, filters, 29, 29)
else:
    output_shape = (batch_size, 29, 29, filters)
decoder_deconv_3_upsamp = Conv2DTranspose(filters,
                                          kernel_size=(3, 3),
                                          strides=(2, 2),
                                          padding='valid',
                                          activation='relu')
decoder_mean_squash = Conv2D(img_chns,
                             kernel_size=2,
                             padding='valid',
                             activation='sigmoid')

hid_decoded = decoder_hid(z)
up_decoded = decoder_upsample(hid_decoded)
reshape_decoded = decoder_reshape(up_decoded)
deconv_1_decoded = decoder_deconv_1(reshape_decoded)
deconv_2_decoded = decoder_deconv_2(deconv_1_decoded)
x_decoded_relu = decoder_deconv_3_upsamp(deconv_2_decoded)
x_decoded_mean_squash = decoder_mean_squash(x_decoded_relu)




# Custom loss layer
class CustomVariationalLayer(Layer):
    def __init__(self, **kwargs):
        self.is_placeholder = True
        super(CustomVariationalLayer, self).__init__(**kwargs)

    def vae_loss(self, x, x_decoded_mean_squash):
        print("inside loss?")
        print(x, x_decoded_mean_squash)
        x = K.flatten(x)
        x_decoded_mean_squash = K.flatten(x_decoded_mean_squash)

        print(x, x_decoded_mean_squash)
        xent_loss = img_rows * img_cols * metrics.binary_crossentropy(x, x_decoded_mean_squash)
        kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
        print(xent_loss, kl_loss)
        return K.mean(xent_loss + kl_loss)

    def call(self, inputs):
        x = inputs[0]
        x_decoded_mean_squash = inputs[1]
        loss = self.vae_loss(x, x_decoded_mean_squash)
        self.add_loss(loss, inputs=inputs)
        # We don't use this output.
        return x



y = CustomVariationalLayer()([x, x_decoded_mean_squash])
vae = Model(x, y)

rmsprop = optimizers.legacy.RMSprop(learning_rate=lr, rho=0.9, epsilon=1e-08, decay=0.0)
# sgd = optimizers.sgd(lr=0.01, momentum=0.0, decay=0.0, nesterov=False)
vae.compile(optimizer=rmsprop, loss=None)



if __name__ == '__main__':
    vae.summary()
    if dataset == 'mnist':
        (x_train, _), (x_test, y_test) = mnist.load_data()

        x_train = x_train.astype('float32') / 255.
        x_train = x_train.reshape((x_train.shape[0],) + original_img_size)
        x_test = x_test.astype('float32') / 255.
        x_test = x_test.reshape((x_test.shape[0],) + original_img_size)



        print('x_train.shape:', x_train.shape)

    elif dataset == 'histo-dev':
        import h5py
        filename = '/hps/nobackup/research/stegle/users/willj/GTEx/data/h5py/patches-all_s-10_p-50.hdf5'
        with h5py.File(filename) as f:
            patches = f['/patches'].value
            labels = f['/labels'].value

        patches = [resize(x, (256, 256)) for x in patches]
        x_train, x_test, y_train, y_test = train_test_split(patches, labels, test_size=0.4, random_state=42)
        x_test, x_val, y_test, y_val = train_test_split(x_test, y_test, test_size=0.5, random_state=42)

        x_train = np.array(x_train, dtype='float32')
        x_test = np.array(x_test, dtype='float32')
        x_val = np.array(x_val, dtype='float32')

        # import pdb; pdb.set_trace()

        print('x_train.shape:', x_train.shape)

    elif dataset == 'cifar10':
        (x_train, _), (x_test, y_test) = cifar10.load_data()
        x_train = x_train.astype('float32') / 255.
        x_test = x_test.astype('float32') / 255.

    if dataset == '128-patches':
        import h5py


        filename = '/hps/nobackup/research/stegle/users/willj/GTEx/data/patches/Lung/GTEX-144GM-0126_128.hdf5'
        with h5py.File(filename) as f:
            patches1 = f['/patches'].value
            labels1 = ['144GM'] * len(patches1)

        filename = '/hps/nobackup/research/stegle/users/willj/GTEx/data/patches/Lung/GTEX-13N1W-0726_128.hdf5'
        with h5py.File(filename) as f:
            patches2 = f['/patches'].value
            labels2 = ['13N1W'] * len(patches2)

        patches = np.concatenate([patches1, patches2])
        labels = labels1 + labels2

        x_train, x_test, y_train, y_test = train_test_split(patches, labels, test_size=0.4, random_state=42)
        x_test, x_val, y_test, y_val = train_test_split(x_test, y_test, test_size=0.5, random_state=42)

        x_train = np.array(x_train, dtype='float32') / 255.
        x_test = np.array(x_test, dtype='float32') / 255.
        x_val = np.array(x_val, dtype='float32') / 255.
        import pdb; pdb.set_trace()


    print(x_train)
    vae.fit(x_train, shuffle=True, epochs=epochs, batch_size=batch_size)

    vae.save('models/vae-{dataset}.hdf5'.format(dataset=dataset))



    # build a model to project inputs on the latent space
    encoder = Model(x, z_mean)
    encoder.save('models/encoder-{dataset}.hdf5'.format(dataset=dataset))

    # display a 2D plot of the digit classes in the latent space
    x_test_encoded = encoder.predict(x_test, batch_size=batch_size)
    # plt.figure(figsize=(6, 6))
    # plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], c=y_test)
    # plt.colorbar()
    # plt.show()

    # build a digit generator that can sample from the learned distribution
    decoder_input = Input(shape=(latent_dim,))
    _hid_decoded = decoder_hid(decoder_input)
    _up_decoded = decoder_upsample(_hid_decoded)
    _reshape_decoded = decoder_reshape(_up_decoded)
    _deconv_1_decoded = decoder_deconv_1(_reshape_decoded)
    _deconv_2_decoded = decoder_deconv_2(_deconv_1_decoded)
    _x_decoded_relu = decoder_deconv_3_upsamp(_deconv_2_decoded)
    _x_decoded_mean_squash = decoder_mean_squash(_x_decoded_relu)
    generator = Model(decoder_input, _x_decoded_mean_squash)
    generator.save('models/generator-{dataset}.hdf5'.format(dataset=dataset))


<function sampling at 0x7f97582cc940>
inside loss?
Tensor("Placeholder:0", shape=(None, 28, 28, 1), dtype=float32) Tensor("Placeholder_1:0", shape=(None, 28, 28, 1), dtype=float32)
Tensor("custom_variational_layer_10/Reshape:0", shape=(None,), dtype=float32) Tensor("custom_variational_layer_10/Reshape_1:0", shape=(None,), dtype=float32)
Tensor("custom_variational_layer_10/mul_2:0", shape=(), dtype=float32) KerasTensor(type_spec=TensorSpec(shape=(None,), dtype=tf.float32, name=None), name='tf.math.multiply_19/Mul:0', description="created by layer 'tf.math.multiply_19'")
Model: "model_10"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_11 (InputLayer)       [(None, 28, 28, 1)]          0         []                            
                                                                                                  
 conv2d_50

TypeError: in user code:

    File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py", line 1401, in train_function  *
        return step_function(self, iterator)
    File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py", line 1384, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py", line 1373, in run_step  **
        outputs = model.train_step(data)
    File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py", line 1151, in train_step
        loss = self.compute_loss(x, y, y_pred, sample_weight)
    File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py", line 1209, in compute_loss
        return self.compiled_loss(
    File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/compile_utils.py", line 329, in __call__
        self._total_loss_mean.update_state(
    File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/metrics_utils.py", line 77, in decorated
        result = update_state_fn(*args, **kwargs)
    File "/usr/local/lib/python3.10/dist-packages/keras/src/metrics/base_metric.py", line 140, in update_state_fn
        return ag_update_state(*args, **kwargs)
    File "/usr/local/lib/python3.10/dist-packages/keras/src/metrics/base_metric.py", line 528, in update_state  **
        update_total_op = self.total.assign_add(value_sum)
    File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/keras_tensor.py", line 285, in __array__
        raise TypeError(

    TypeError: You are passing KerasTensor(type_spec=TensorSpec(shape=(), dtype=tf.float32, name=None), name='tf.math.reduce_sum_9/Sum:0', description="created by layer 'tf.math.reduce_sum_9'"), an intermediate Keras symbolic input/output, to a TF API that does not allow registering custom dispatchers, such as `tf.cond`, `tf.function`, gradient tapes, or `tf.map_fn`. Keras Functional model construction only supports TF API calls that *do* support dispatching, such as `tf.math.add` or `tf.reshape`. Other APIs cannot be called directly on symbolic Kerasinputs/outputs. You can work around this limitation by putting the operation in a custom Keras layer `call` and calling that layer on this symbolic input/output.
