In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Model, GPT2LMHeadModel
import imageio, os
from torchvision import transforms
from PIL import Image
import numpy as np
from tqdm import tqdm

run_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
# video settings
RESOLUTION_WIDTH = 128
RESOLUTION_HEIGHT = 128
CHANNELS = 3
CONVERTED_FRAMERATE = 24

# model settings
WINDOW_SIZE = 48
ENCODED_DIM = 768

In [3]:
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, 3, 1, 1)
        )
    def forward(self, x):
        return x + self.block(x)

class ConvAutoencoder(nn.Module):
    def __init__(self, in_channels=CHANNELS, latent_dim=ENCODED_DIM, input_resolution=(RESOLUTION_WIDTH, RESOLUTION_HEIGHT)):
        super().__init__()
        self.latent_dim = latent_dim

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, 4, 2, 1),  # 64x64
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),           # 32x32
            nn.ReLU(),
            ResidualBlock(64),
            nn.Conv2d(64, 128, 4, 2, 1),          # 16x16
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, 2, 1),         # 8x8
            nn.ReLU()
        )

        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, *input_resolution)
            enc_out = self.encoder(dummy)
            self.flattened_size = enc_out.view(1, -1).shape[1]

        self.encoder_fc = nn.Linear(self.flattened_size, latent_dim)
        self.decoder_fc = nn.Linear(latent_dim, self.flattened_size)

        self.decoder = nn.Sequential(
            nn.Unflatten(1, enc_out.shape[1:]),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),  # 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),   # 32x32
            nn.ReLU(),
            ResidualBlock(64),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),    # 64x64
            nn.ReLU(),
            nn.ConvTranspose2d(32, in_channels, 4, 2, 1),  # 128x128
            nn.Tanh()
        )

    def encode(self, x):
        x = self.encoder(x)
        x = torch.flatten(x, 1)
        x = F.gelu(self.encoder_fc(x))
        
        return x

    def decode(self, z):
        z = F.gelu(self.decoder_fc(z))
        z = self.decoder(z)
        
        return z

In [4]:
class ImageProcessor:
    def tensor_to_pil(self, image_tensor: torch.Tensor) -> Image.Image:
        """
        Convert a tensor to a PIL Image.
        
        Args:
            image_tensor (torch.Tensor): A tensor of shape (C, H, W) with pixel values in the range [0, 1].
        
        Returns:
            Image.Image: A PIL Image object.
        """
        # Clamp to [-1, 1], convert to [0, 255] and uint8
        image_np = (image_tensor.clamp(-1, 1).mul(255).byte().cpu().permute(1, 2, 0).numpy())
        return Image.fromarray(image_np)
    
    def pil_to_tensor(self, image: Image.Image) -> torch.Tensor:
        """
        Convert a PIL image to a PyTorch tensor of shape (C, H, W) with values in [0, 1].
        
        Args:
            image (Image.Image): A PIL Image object.
        
        Returns:
            torch.Tensor: A tensor of shape (C, H, W) with pixel values in the range [0, 1].
        """
        return transforms.ToTensor()(image) * 2 - 1  # Already returns (C, H, W)

In [5]:
# Load the autoencoder
autoencoder = ConvAutoencoder()
autoenced_state_dict = torch.load("checkpoints/run1/autoenc.pth", map_location=run_device)
autoencoder.load_state_dict(autoenced_state_dict)
autoencoder = autoencoder.to(run_device).eval()

# load the transformer
transformer = GPT2LMHeadModel.from_pretrained("checkpoints/run1/gpt2_decap").transformer
transformer = transformer.to(run_device).eval()

# Load image processor
proc = ImageProcessor()

In [6]:
def generate_frames(num_frames, context_length=WINDOW_SIZE, autoencoder=autoencoder, transformer=transformer):
    autoencoder.eval()
    transformer.eval()
    
    total_seq = torch.zeros(num_frames + context_length, CHANNELS, RESOLUTION_HEIGHT, RESOLUTION_WIDTH, device=run_device)

    with torch.no_grad():
        for i in tqdm(range(num_frames)):
            current_slice = total_seq[i:i + context_length]
            
            slice_latents = autoencoder.encode(current_slice)
            
            print(torch.max(slice_latents))
            
            #slice_latents += torch.rand(slice_latents.shape, device=run_device) * 1000
            
            prediction_latents = transformer(inputs_embeds=slice_latents.unsqueeze(0)).last_hidden_state
            
            prediction_frame = autoencoder.decode(prediction_latents.squeeze(0))  # shape: (context_length, C, H, W)
            
            total_seq[context_length + i] = prediction_frame[-1]
    
    return total_seq

In [7]:
def save_video(frames, output='output.mp4', fps=CONVERTED_FRAMERATE):
    writer = imageio.get_writer(output, fps=fps)
    for frame in tqdm(frames):
        img = frame.permute(1,2,0).numpy()
        img = ((img + 1)/2 * 255).astype('uint8')
        writer.append_data(img)
    writer.close()
    print(f'Saved {output}')

In [None]:
init_img = proc.pil_to_tensor(Image.open("test.png").convert('RGB').resize((RESOLUTION_WIDTH, RESOLUTION_HEIGHT))).to(run_device)

frames = generate_frames(512).detach().cpu()
torch.cuda.empty_cache()
save_video(frames)

  1%|          | 4/512 [00:00<00:43, 11.66it/s]

tensor(51127.3789, device='cuda:0')
tensor(52748.7148, device='cuda:0')
tensor(55248.2461, device='cuda:0')
tensor(57826.1289, device='cuda:0')


  2%|▏         | 8/512 [00:00<00:26, 19.32it/s]

tensor(60337.6172, device='cuda:0')
tensor(62722.3398, device='cuda:0')
tensor(64966.9102, device='cuda:0')
tensor(67028.8906, device='cuda:0')
tensor(68891.7266, device='cuda:0')
tensor(70470.0625, device='cuda:0')
tensor(71720.2188, device='cuda:0')


  3%|▎         | 16/512 [00:00<00:17, 28.87it/s]

tensor(72699.4531, device='cuda:0')
tensor(73443.4766, device='cuda:0')
tensor(74004.0469, device='cuda:0')
tensor(74419.7422, device='cuda:0')
tensor(74726.8516, device='cuda:0')
tensor(74956.5625, device='cuda:0')
tensor(75128.3594, device='cuda:0')
tensor(75258.0625, device='cuda:0')


  5%|▍         | 24/512 [00:00<00:15, 32.14it/s]

tensor(75355.7812, device='cuda:0')
tensor(75430.0547, device='cuda:0')
tensor(75486.9766, device='cuda:0')
tensor(75531.2031, device='cuda:0')
tensor(75565.8203, device='cuda:0')
tensor(75593.2031, device='cuda:0')
tensor(75615.0547, device='cuda:0')
tensor(75632.6719, device='cuda:0')


  6%|▋         | 32/512 [00:01<00:14, 34.04it/s]

tensor(75647.1875, device='cuda:0')
tensor(75659.6250, device='cuda:0')
tensor(75670.6562, device='cuda:0')
tensor(75680.8125, device='cuda:0')
tensor(75690.6094, device='cuda:0')
tensor(75700.3984, device='cuda:0')
tensor(75710.2266, device='cuda:0')


  8%|▊         | 40/512 [00:01<00:13, 35.22it/s]

tensor(75720.1953, device='cuda:0')
tensor(75730.4375, device='cuda:0')
tensor(75741.1250, device='cuda:0')
tensor(75752.3516, device='cuda:0')
tensor(75764.0703, device='cuda:0')
tensor(75776.4922, device='cuda:0')
tensor(75789.6016, device='cuda:0')


  9%|▉         | 48/512 [00:01<00:13, 35.62it/s]

tensor(75802.9453, device='cuda:0')
tensor(75816.3672, device='cuda:0')
tensor(75829.7578, device='cuda:0')
tensor(75842.9531, device='cuda:0')
tensor(75855.7500, device='cuda:0')
tensor(75867.6172, device='cuda:0')
tensor(75878.2891, device='cuda:0')
tensor(75887.3828, device='cuda:0')


 11%|█         | 56/512 [00:01<00:12, 36.64it/s]

tensor(75898.6016, device='cuda:0')
tensor(75946.0156, device='cuda:0')
tensor(76015.6016, device='cuda:0')
tensor(76093.7656, device='cuda:0')
tensor(76173.1875, device='cuda:0')
tensor(76247.4922, device='cuda:0')
tensor(76316.4062, device='cuda:0')
tensor(76380.1484, device='cuda:0')
tensor(76437.3047, device='cuda:0')


 12%|█▎        | 64/512 [00:02<00:12, 36.38it/s]

tensor(76487.5078, device='cuda:0')
tensor(76530.2969, device='cuda:0')
tensor(76565.9766, device='cuda:0')
tensor(76594.8984, device='cuda:0')
tensor(76617.7422, device='cuda:0')
tensor(76635.4297, device='cuda:0')
tensor(76648.8672, device='cuda:0')
tensor(76658.8438, device='cuda:0')


 14%|█▍        | 72/512 [00:02<00:12, 35.94it/s]

tensor(76666.0078, device='cuda:0')
tensor(76670.8516, device='cuda:0')
tensor(76673.7891, device='cuda:0')
tensor(76675.1328, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')


 16%|█▌        | 80/512 [00:02<00:11, 36.35it/s]

tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')


 17%|█▋        | 88/512 [00:02<00:11, 36.64it/s]

tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')


 19%|█▉        | 96/512 [00:02<00:11, 36.72it/s]

tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')


 20%|██        | 104/512 [00:03<00:10, 37.72it/s]

tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')


 22%|██▏       | 112/512 [00:03<00:11, 36.07it/s]

tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')


 23%|██▎       | 120/512 [00:03<00:11, 34.79it/s]

tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76675.1484, device='cuda:0')
tensor(76674.1328, device='cuda:0')
tensor(76672.2109, device='cuda:0')


 24%|██▍       | 124/512 [00:03<00:11, 33.56it/s]

tensor(76669.5234, device='cuda:0')
tensor(76666.1250, device='cuda:0')
tensor(76662.2188, device='cuda:0')
tensor(76657.8984, device='cuda:0')
tensor(76653.2188, device='cuda:0')
tensor(76648.2344, device='cuda:0')
tensor(76642.9922, device='cuda:0')


 26%|██▌       | 132/512 [00:03<00:11, 34.14it/s]

tensor(76637.5781, device='cuda:0')
tensor(76632.0234, device='cuda:0')
tensor(76626.3516, device='cuda:0')
tensor(76620.5469, device='cuda:0')
tensor(76614.6328, device='cuda:0')
tensor(76608.6719, device='cuda:0')
tensor(76602.5938, device='cuda:0')
tensor(76596.4844, device='cuda:0')


 27%|██▋       | 140/512 [00:04<00:10, 34.34it/s]

tensor(76590.3203, device='cuda:0')
tensor(76584.1094, device='cuda:0')
tensor(76577.8594, device='cuda:0')
tensor(76571.5312, device='cuda:0')
tensor(76565.1250, device='cuda:0')
tensor(76558.5234, device='cuda:0')
tensor(76551.7578, device='cuda:0')
tensor(76544.9219, device='cuda:0')


 29%|██▉       | 148/512 [00:04<00:10, 34.42it/s]

tensor(76538.0312, device='cuda:0')
tensor(76531.1797, device='cuda:0')
tensor(76524.5078, device='cuda:0')
tensor(76518.0625, device='cuda:0')
tensor(76511.9453, device='cuda:0')
tensor(76506.1719, device='cuda:0')
tensor(76500.7500, device='cuda:0')


 30%|███       | 156/512 [00:04<00:10, 34.22it/s]

tensor(76495.6641, device='cuda:0')
tensor(76490.8984, device='cuda:0')
tensor(76486.4609, device='cuda:0')
tensor(76482.2422, device='cuda:0')
tensor(76478.1484, device='cuda:0')
tensor(76474.1328, device='cuda:0')
tensor(76470.0703, device='cuda:0')
tensor(76466.0156, device='cuda:0')


 32%|███▏      | 164/512 [00:04<00:10, 34.10it/s]

tensor(76461.9688, device='cuda:0')
tensor(76457.8906, device='cuda:0')
tensor(76453.8359, device='cuda:0')
tensor(76449.7031, device='cuda:0')
tensor(76445.5312, device='cuda:0')
tensor(76441.3438, device='cuda:0')
tensor(76437.1250, device='cuda:0')


 34%|███▎      | 172/512 [00:05<00:10, 33.96it/s]

tensor(76432.8672, device='cuda:0')
tensor(76428.5781, device='cuda:0')
tensor(76424.2734, device='cuda:0')
tensor(76419.9297, device='cuda:0')
tensor(76415.5625, device='cuda:0')
tensor(76411.1562, device='cuda:0')
tensor(76406.7109, device='cuda:0')


 34%|███▍      | 176/512 [00:05<00:09, 33.77it/s]

tensor(76402.2500, device='cuda:0')
tensor(76397.7812, device='cuda:0')
tensor(76393.3203, device='cuda:0')
tensor(76388.8672, device='cuda:0')
tensor(76384.4297, device='cuda:0')
tensor(76380.0234, device='cuda:0')


 36%|███▌      | 184/512 [00:05<00:09, 33.86it/s]

tensor(76375.6484, device='cuda:0')
tensor(76371.3281, device='cuda:0')
tensor(76367.0703, device='cuda:0')
tensor(76362.8672, device='cuda:0')
tensor(76358.6875, device='cuda:0')
tensor(76354.5391, device='cuda:0')
tensor(76350.4531, device='cuda:0')
tensor(76346.4062, device='cuda:0')


 38%|███▊      | 192/512 [00:05<00:09, 34.05it/s]

tensor(76342.3984, device='cuda:0')
tensor(76338.4141, device='cuda:0')
tensor(76334.4453, device='cuda:0')
tensor(76330.5156, device='cuda:0')
tensor(76326.6641, device='cuda:0')
tensor(76322.8906, device='cuda:0')
tensor(76319.2422, device='cuda:0')
tensor(76315.6953, device='cuda:0')


 39%|███▉      | 200/512 [00:05<00:09, 34.24it/s]

tensor(76312.2344, device='cuda:0')
tensor(76308.8984, device='cuda:0')
tensor(76305.6328, device='cuda:0')
tensor(76302.4922, device='cuda:0')
tensor(76299.4531, device='cuda:0')
tensor(76296.4844, device='cuda:0')
tensor(76293.5859, device='cuda:0')
tensor(76290.7422, device='cuda:0')


 41%|████      | 208/512 [00:06<00:08, 33.99it/s]

tensor(76287.9844, device='cuda:0')
tensor(76285.2734, device='cuda:0')
tensor(76282.6250, device='cuda:0')
tensor(76280.0156, device='cuda:0')
tensor(76277.4688, device='cuda:0')
tensor(76274.9766, device='cuda:0')
tensor(76272.5312, device='cuda:0')
tensor(76270.1484, device='cuda:0')


 42%|████▏     | 216/512 [00:06<00:08, 33.85it/s]

tensor(76267.7969, device='cuda:0')
tensor(76265.5078, device='cuda:0')
tensor(76263.2188, device='cuda:0')
tensor(76260.9688, device='cuda:0')
tensor(76258.7422, device='cuda:0')
tensor(76256.5625, device='cuda:0')
tensor(76254.3984, device='cuda:0')


 44%|████▍     | 224/512 [00:06<00:08, 32.89it/s]

tensor(76252.2578, device='cuda:0')
tensor(76250.1641, device='cuda:0')
tensor(76248.1016, device='cuda:0')
tensor(76246.0547, device='cuda:0')
tensor(76244.0234, device='cuda:0')
tensor(76242.0469, device='cuda:0')
tensor(76240.1172, device='cuda:0')


 45%|████▍     | 228/512 [00:06<00:08, 32.82it/s]

tensor(76238.2344, device='cuda:0')
tensor(76236.3672, device='cuda:0')
tensor(76234.5781, device='cuda:0')
tensor(76232.8203, device='cuda:0')
tensor(76231.1094, device='cuda:0')
tensor(76229.4375, device='cuda:0')


 46%|████▌     | 236/512 [00:07<00:08, 32.70it/s]

tensor(76227.7734, device='cuda:0')
tensor(76226.1328, device='cuda:0')
tensor(76224.5234, device='cuda:0')
tensor(76222.9453, device='cuda:0')
tensor(76221.3984, device='cuda:0')
tensor(76219.8594, device='cuda:0')
tensor(76218.3594, device='cuda:0')


 48%|████▊     | 244/512 [00:07<00:08, 33.07it/s]

tensor(76216.8516, device='cuda:0')
tensor(76215.3750, device='cuda:0')
tensor(76213.9453, device='cuda:0')
tensor(76212.5156, device='cuda:0')
tensor(76211.1094, device='cuda:0')
tensor(76209.7344, device='cuda:0')
tensor(76208.3984, device='cuda:0')


 48%|████▊     | 248/512 [00:07<00:07, 33.57it/s]

tensor(76207.0625, device='cuda:0')
tensor(76205.7500, device='cuda:0')
tensor(76204.4688, device='cuda:0')
tensor(76203.2031, device='cuda:0')
tensor(76201.9453, device='cuda:0')
tensor(76200.7266, device='cuda:0')
tensor(76199.5156, device='cuda:0')


 50%|█████     | 256/512 [00:07<00:07, 32.81it/s]

tensor(76198.3359, device='cuda:0')
tensor(76197.1562, device='cuda:0')
tensor(76196.0234, device='cuda:0')
tensor(76194.8750, device='cuda:0')
tensor(76193.7422, device='cuda:0')
tensor(76192.6094, device='cuda:0')
tensor(76191.5078, device='cuda:0')


 52%|█████▏    | 264/512 [00:07<00:07, 31.89it/s]

tensor(76190.4219, device='cuda:0')
tensor(76189.3359, device='cuda:0')
tensor(76188.2422, device='cuda:0')
tensor(76187.1719, device='cuda:0')
tensor(76186.1016, device='cuda:0')
tensor(76185.0391, device='cuda:0')
tensor(76183.9766, device='cuda:0')


 53%|█████▎    | 272/512 [00:08<00:07, 31.24it/s]

tensor(76182.9219, device='cuda:0')
tensor(76181.8906, device='cuda:0')
tensor(76180.8438, device='cuda:0')
tensor(76179.8359, device='cuda:0')
tensor(76178.8516, device='cuda:0')
tensor(76177.8594, device='cuda:0')
tensor(76176.8750, device='cuda:0')


 54%|█████▍    | 276/512 [00:08<00:07, 31.54it/s]

tensor(76175.9219, device='cuda:0')
tensor(76174.9609, device='cuda:0')
tensor(76174.0312, device='cuda:0')
tensor(76173.1016, device='cuda:0')
tensor(76172.1953, device='cuda:0')
tensor(76171.2891, device='cuda:0')
tensor(76170.3906, device='cuda:0')


 55%|█████▌    | 284/512 [00:08<00:07, 32.50it/s]

tensor(76169.5156, device='cuda:0')
tensor(76168.6719, device='cuda:0')
tensor(76167.8203, device='cuda:0')
tensor(76166.9766, device='cuda:0')
tensor(76166.1484, device='cuda:0')
tensor(76165.3359, device='cuda:0')
tensor(76164.5234, device='cuda:0')


 57%|█████▋    | 292/512 [00:08<00:06, 31.84it/s]

tensor(76163.7188, device='cuda:0')
tensor(76162.9375, device='cuda:0')
tensor(76162.1719, device='cuda:0')
tensor(76161.3984, device='cuda:0')
tensor(76160.6406, device='cuda:0')
tensor(76159.8906, device='cuda:0')
tensor(76159.1641, device='cuda:0')


 59%|█████▊    | 300/512 [00:09<00:06, 33.61it/s]

tensor(76158.4219, device='cuda:0')
tensor(76157.7031, device='cuda:0')
tensor(76157., device='cuda:0')
tensor(76156.3125, device='cuda:0')
tensor(76155.6172, device='cuda:0')
tensor(76154.9531, device='cuda:0')
tensor(76154.2812, device='cuda:0')
tensor(76153.6094, device='cuda:0')


 60%|██████    | 308/512 [00:09<00:06, 33.15it/s]

tensor(76152.9844, device='cuda:0')
tensor(76152.3516, device='cuda:0')
tensor(76151.7188, device='cuda:0')
tensor(76151.1094, device='cuda:0')
tensor(76150.4922, device='cuda:0')
tensor(76149.8828, device='cuda:0')
tensor(76149.2812, device='cuda:0')


 61%|██████    | 312/512 [00:09<00:06, 32.96it/s]

tensor(76148.7031, device='cuda:0')
tensor(76148.1172, device='cuda:0')
tensor(76147.5547, device='cuda:0')
tensor(76147., device='cuda:0')
tensor(76146.4375, device='cuda:0')
tensor(76145.8828, device='cuda:0')
tensor(76145.3438, device='cuda:0')


 62%|██████▎   | 320/512 [00:09<00:05, 32.87it/s]

In [None]:
1 / 0

ZeroDivisionError: division by zero

In [None]:
frames.shape

torch.Size([560, 3, 128, 128])

In [None]:
img = Image.open("test.png").convert('RGB').resize((RESOLUTION_WIDTH, RESOLUTION_HEIGHT))

In [None]:
enc_img = proc.pil_to_tensor(img).to(run_device).unsqueeze(0)

In [None]:
latent = autoenc.encode(enc_img)

NameError: name 'autoenc' is not defined

In [None]:
prediction = transformer(inputs_embeds=latent).last_hidden_state

In [None]:
decoded = autoenc.decode(prediction)

In [None]:
proc.tensor_to_pil(decoded.squeeze(0))

In [None]:
init_img

tensor([[[0.4941, 0.4980, 0.4980,  ..., 0.4157, 0.4157, 0.4157],
         [0.4980, 0.5020, 0.4980,  ..., 0.4196, 0.4196, 0.4196],
         [0.4980, 0.5020, 0.5059,  ..., 0.4196, 0.4157, 0.4196],
         ...,
         [0.4314, 0.3804, 0.4667,  ..., 0.2784, 0.3804, 0.3686],
         [0.6471, 0.6235, 0.5843,  ..., 0.2627, 0.2706, 0.2196],
         [0.4588, 0.4196, 0.2902,  ..., 0.3020, 0.2784, 0.2784]],

        [[0.6745, 0.6745, 0.6784,  ..., 0.6196, 0.6196, 0.6235],
         [0.6784, 0.6824, 0.6784,  ..., 0.6275, 0.6275, 0.6275],
         [0.6784, 0.6824, 0.6863,  ..., 0.6275, 0.6235, 0.6275],
         ...,
         [0.4510, 0.4039, 0.4980,  ..., 0.3294, 0.4275, 0.4039],
         [0.6863, 0.6706, 0.6353,  ..., 0.3216, 0.3255, 0.2706],
         [0.5137, 0.4667, 0.3176,  ..., 0.3765, 0.3490, 0.3529]],

        [[0.8392, 0.8392, 0.8392,  ..., 0.8392, 0.8353, 0.8314],
         [0.8353, 0.8353, 0.8392,  ..., 0.8431, 0.8392, 0.8353],
         [0.8353, 0.8392, 0.8431,  ..., 0.8392, 0.8392, 0.