In [1]:
# import transformer_lens
# import transformer_lens.utils as utils
# from transformer_lens.hook_points import (
    # HookPoint,
# )  # Hooking utilities
from transformer_lens import HookedTransformer
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import datasets

torch.manual_seed(42)

<torch._C.Generator at 0x722d51384f90>

Definimos o modelo Pythia usando a bilbioteca TransformerLens, que permite ter acesso às ativações internas de diversos modelos.

In [2]:
model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m-deduped", device='cuda:0')

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m-deduped into HookedTransformer


Podemos então fazera inferência e visualizar todas as camadas acessíveis através da variável cache

In [3]:
text = ["Hello, world!!!!!!!!!!!!!!, Hello, world!!!!!!!!!!!!!! Hello, world!!!!!!!!!!!!!!", "Hello,"]
tokens = model.to_tokens(text)
print(tokens)

logits, cache = model.run_with_cache(tokens)

print("Cache keys:", cache.keys())

tensor([[    0, 12092,    13,  1533, 45939, 18963,  4672,    13, 24387,    13,
          1533, 45939, 18963,  4672, 24387,    13,  1533, 45939, 18963,  4672],
        [    0, 12092,    13,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0]],
       device='cuda:0')
Cache keys: dict_keys(['hook_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_rot_q', 'blocks.0.attn.hook_rot_k', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook

Podemos utilizar a biblioteca datasets do HuggingFace para importar um dataset pronto. Neste caso utilizamos o wikitext por ter um domínio bem definido e ser pequeno o suficiente para rodarmos localmente.

In [3]:
dataset = datasets.load_dataset("Salesforce/wikitext", split="train[:1801350]", ignore_verifications=True)

Found cached dataset parquet (/home/flfp/.cache/huggingface/datasets/Salesforce___parquet/wikitext-103-raw-v1-7bb180478b704b56/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


Podemos então gerar um dicionário com as ativações da residual stream de todas as camadas dessa forma:

In [4]:
layers = list(range(1, 2))
activations = {}
for layer in layers:
    activations[layer] = []

def create_batch(batch):
    texts = [item["text"] for item in batch]
    tokens = model.to_tokens(texts)
    return tokens

dataloader = DataLoader(dataset, batch_size=4, shuffle=False, collate_fn=create_batch)


for text in tqdm(dataloader):
    text = text.to('cuda:0')
    _, cache = model.run_with_cache(text, remove_batch_dim=False)

    for layer in layers:
        key = f"blocks.{layer}.hook_resid_post"
        print(cache[key].shape)
        activations[layer].append(cache[key])

    # algum leak aqui, memória sobe sem parar

    # nem isso funciona :/
    del text
    del cache
    del _


activations


  0%|          | 6/450338 [00:00<2:32:03, 49.36it/s]

torch.Size([4, 178, 512])
torch.Size([4, 119, 512])
torch.Size([4, 333, 512])
torch.Size([4, 230, 512])
torch.Size([4, 257, 512])
torch.Size([4, 337, 512])
torch.Size([4, 355, 512])
torch.Size([4, 301, 512])
torch.Size([4, 80, 512])


  0%|          | 17/450338 [00:00<2:28:20, 50.59it/s]

torch.Size([4, 164, 512])
torch.Size([4, 84, 512])
torch.Size([4, 277, 512])
torch.Size([4, 230, 512])
torch.Size([4, 114, 512])
torch.Size([4, 239, 512])
torch.Size([4, 303, 512])
torch.Size([4, 196, 512])
torch.Size([4, 82, 512])


  0%|          | 33/450338 [00:00<2:04:05, 60.48it/s]

torch.Size([4, 90, 512])
torch.Size([4, 15, 512])
torch.Size([4, 17, 512])
torch.Size([4, 229, 512])
torch.Size([4, 189, 512])
torch.Size([4, 87, 512])
torch.Size([4, 14, 512])
torch.Size([4, 19, 512])
torch.Size([4, 23, 512])
torch.Size([4, 135, 512])
torch.Size([4, 42, 512])
torch.Size([4, 172, 512])
torch.Size([4, 123, 512])
torch.Size([4, 195, 512])
torch.Size([4, 137, 512])
torch.Size([4, 96, 512])
torch.Size([4, 99, 512])
torch.Size([4, 10, 512])
torch.Size([4, 96, 512])


  0%|          | 50/450338 [00:00<1:46:08, 70.70it/s]

torch.Size([4, 94, 512])
torch.Size([4, 104, 512])
torch.Size([4, 99, 512])
torch.Size([4, 210, 512])
torch.Size([4, 156, 512])
torch.Size([4, 99, 512])
torch.Size([4, 73, 512])
torch.Size([4, 96, 512])
torch.Size([4, 252, 512])
torch.Size([4, 273, 512])
torch.Size([4, 7, 512])
torch.Size([4, 15, 512])
torch.Size([4, 18, 512])
torch.Size([4, 20, 512])
torch.Size([4, 14, 512])
torch.Size([4, 24, 512])
torch.Size([4, 15, 512])


  0%|          | 71/450338 [00:01<1:27:08, 86.12it/s]

torch.Size([4, 22, 512])
torch.Size([4, 19, 512])
torch.Size([4, 20, 512])
torch.Size([4, 23, 512])
torch.Size([4, 14, 512])
torch.Size([4, 14, 512])
torch.Size([4, 18, 512])
torch.Size([4, 24, 512])
torch.Size([4, 14, 512])
torch.Size([4, 18, 512])
torch.Size([4, 18, 512])
torch.Size([4, 26, 512])
torch.Size([4, 32, 512])
torch.Size([4, 27, 512])
torch.Size([4, 146, 512])
torch.Size([4, 253, 512])
torch.Size([4, 82, 512])
torch.Size([4, 115, 512])
torch.Size([4, 159, 512])
torch.Size([4, 195, 512])


  0%|          | 80/450338 [00:01<1:31:42, 81.83it/s]

torch.Size([4, 155, 512])
torch.Size([4, 210, 512])
torch.Size([4, 137, 512])
torch.Size([4, 151, 512])
torch.Size([4, 158, 512])
torch.Size([4, 126, 512])
torch.Size([4, 92, 512])
torch.Size([4, 252, 512])
torch.Size([4, 11, 512])
torch.Size([4, 440, 512])
torch.Size([4, 379, 512])
torch.Size([4, 380, 512])


  0%|          | 102/450338 [00:01<1:35:52, 78.26it/s]

torch.Size([4, 224, 512])
torch.Size([4, 111, 512])
torch.Size([4, 73, 512])
torch.Size([4, 26, 512])
torch.Size([4, 13, 512])
torch.Size([4, 11, 512])
torch.Size([4, 15, 512])
torch.Size([4, 99, 512])
torch.Size([4, 7, 512])
torch.Size([4, 13, 512])
torch.Size([4, 11, 512])
torch.Size([4, 10, 512])
torch.Size([4, 21, 512])
torch.Size([4, 36, 512])
torch.Size([4, 259, 512])
torch.Size([4, 254, 512])
torch.Size([4, 201, 512])
torch.Size([4, 237, 512])


  0%|          | 111/450338 [00:01<1:54:07, 65.75it/s]

torch.Size([4, 395, 512])
torch.Size([4, 312, 512])
torch.Size([4, 268, 512])
torch.Size([4, 237, 512])
torch.Size([4, 189, 512])
torch.Size([4, 276, 512])
torch.Size([4, 131, 512])
torch.Size([4, 235, 512])
torch.Size([4, 165, 512])
torch.Size([4, 21, 512])
torch.Size([4, 19, 512])
torch.Size([4, 8, 512])


  0%|          | 119/450338 [00:01<1:53:10, 66.30it/s]

torch.Size([4, 129, 512])
torch.Size([4, 401, 512])
torch.Size([4, 180, 512])
torch.Size([4, 8, 512])
torch.Size([4, 165, 512])
torch.Size([4, 339, 512])
torch.Size([4, 171, 512])
torch.Size([4, 331, 512])
torch.Size([4, 483, 512])
torch.Size([4, 361, 512])


  0%|          | 134/450338 [00:02<2:07:58, 58.63it/s]

torch.Size([4, 228, 512])
torch.Size([4, 170, 512])
torch.Size([4, 210, 512])
torch.Size([4, 168, 512])
torch.Size([4, 168, 512])
torch.Size([4, 256, 512])
torch.Size([4, 173, 512])
torch.Size([4, 224, 512])
torch.Size([4, 252, 512])
torch.Size([4, 133, 512])
torch.Size([4, 205, 512])
torch.Size([4, 10, 512])
torch.Size([4, 7, 512])
torch.Size([4, 151, 512])


  0%|          | 150/450338 [00:02<1:53:18, 66.21it/s]

torch.Size([4, 119, 512])
torch.Size([4, 94, 512])
torch.Size([4, 165, 512])
torch.Size([4, 163, 512])
torch.Size([4, 194, 512])
torch.Size([4, 162, 512])
torch.Size([4, 131, 512])
torch.Size([4, 43, 512])
torch.Size([4, 167, 512])
torch.Size([4, 12, 512])
torch.Size([4, 9, 512])
torch.Size([4, 23, 512])
torch.Size([4, 12, 512])
torch.Size([4, 221, 512])
torch.Size([4, 336, 512])
torch.Size([4, 212, 512])


  0%|          | 166/450338 [00:02<1:55:01, 65.23it/s]

torch.Size([4, 340, 512])
torch.Size([4, 122, 512])
torch.Size([4, 149, 512])
torch.Size([4, 213, 512])
torch.Size([4, 154, 512])
torch.Size([4, 190, 512])
torch.Size([4, 166, 512])
torch.Size([4, 148, 512])
torch.Size([4, 323, 512])
torch.Size([4, 215, 512])
torch.Size([4, 223, 512])
torch.Size([4, 273, 512])


  0%|          | 173/450338 [00:02<2:01:01, 62.00it/s]

torch.Size([4, 276, 512])
torch.Size([4, 150, 512])
torch.Size([4, 145, 512])
torch.Size([4, 258, 512])
torch.Size([4, 194, 512])
torch.Size([4, 370, 512])
torch.Size([4, 272, 512])
torch.Size([4, 214, 512])
torch.Size([4, 290, 512])
torch.Size([4, 233, 512])


  0%|          | 187/450338 [00:02<2:10:41, 57.41it/s]

torch.Size([4, 197, 512])
torch.Size([4, 179, 512])
torch.Size([4, 179, 512])
torch.Size([4, 185, 512])
torch.Size([4, 77, 512])
torch.Size([4, 311, 512])
torch.Size([4, 113, 512])
torch.Size([4, 314, 512])
torch.Size([4, 284, 512])
torch.Size([4, 109, 512])
torch.Size([4, 100, 512])
torch.Size([4, 262, 512])


  0%|          | 200/450338 [00:03<2:05:12, 59.92it/s]

torch.Size([4, 117, 512])
torch.Size([4, 316, 512])
torch.Size([4, 115, 512])
torch.Size([4, 282, 512])
torch.Size([4, 376, 512])
torch.Size([4, 152, 512])
torch.Size([4, 8, 512])
torch.Size([4, 11, 512])
torch.Size([4, 12, 512])
torch.Size([4, 10, 512])
torch.Size([4, 238, 512])
torch.Size([4, 109, 512])
torch.Size([4, 159, 512])
torch.Size([4, 211, 512])


  0%|          | 215/450338 [00:03<2:01:44, 61.62it/s]

torch.Size([4, 87, 512])
torch.Size([4, 234, 512])
torch.Size([4, 313, 512])
torch.Size([4, 189, 512])
torch.Size([4, 136, 512])
torch.Size([4, 217, 512])
torch.Size([4, 135, 512])
torch.Size([4, 157, 512])
torch.Size([4, 222, 512])
torch.Size([4, 265, 512])
torch.Size([4, 131, 512])
torch.Size([4, 173, 512])
torch.Size([4, 125, 512])


  0%|          | 230/450338 [00:03<1:56:51, 64.20it/s]

torch.Size([4, 245, 512])
torch.Size([4, 325, 512])
torch.Size([4, 191, 512])
torch.Size([4, 306, 512])
torch.Size([4, 85, 512])
torch.Size([4, 124, 512])
torch.Size([4, 132, 512])
torch.Size([4, 158, 512])
torch.Size([4, 158, 512])
torch.Size([4, 226, 512])
torch.Size([4, 206, 512])
torch.Size([4, 122, 512])
torch.Size([4, 142, 512])


  0%|          | 237/450338 [00:03<2:06:05, 59.49it/s]

torch.Size([4, 249, 512])
torch.Size([4, 117, 512])
torch.Size([4, 191, 512])
torch.Size([4, 359, 512])
torch.Size([4, 179, 512])
torch.Size([4, 385, 512])
torch.Size([4, 211, 512])
torch.Size([4, 253, 512])
torch.Size([4, 182, 512])
torch.Size([4, 145, 512])
torch.Size([4, 183, 512])


  0%|          | 251/450338 [00:03<2:04:58, 60.02it/s]

torch.Size([4, 212, 512])
torch.Size([4, 46, 512])
torch.Size([4, 272, 512])
torch.Size([4, 128, 512])
torch.Size([4, 80, 512])
torch.Size([4, 160, 512])
torch.Size([4, 158, 512])
torch.Size([4, 361, 512])
torch.Size([4, 135, 512])
torch.Size([4, 72, 512])
torch.Size([4, 202, 512])
torch.Size([4, 330, 512])
torch.Size([4, 235, 512])


  0%|          | 267/450338 [00:04<1:54:04, 65.76it/s]

torch.Size([4, 159, 512])
torch.Size([4, 293, 512])
torch.Size([4, 151, 512])
torch.Size([4, 304, 512])
torch.Size([4, 66, 512])
torch.Size([4, 14, 512])
torch.Size([4, 169, 512])
torch.Size([4, 84, 512])
torch.Size([4, 108, 512])
torch.Size([4, 41, 512])
torch.Size([4, 38, 512])
torch.Size([4, 82, 512])
torch.Size([4, 108, 512])
torch.Size([4, 78, 512])
torch.Size([4, 60, 512])
torch.Size([4, 95, 512])
torch.Size([4, 102, 512])


  0%|          | 285/450338 [00:04<1:44:08, 72.02it/s]

torch.Size([4, 46, 512])
torch.Size([4, 61, 512])
torch.Size([4, 164, 512])
torch.Size([4, 146, 512])
torch.Size([4, 140, 512])
torch.Size([4, 60, 512])
torch.Size([4, 132, 512])
torch.Size([4, 93, 512])
torch.Size([4, 127, 512])
torch.Size([4, 309, 512])
torch.Size([4, 289, 512])
torch.Size([4, 105, 512])
torch.Size([4, 167, 512])
torch.Size([4, 179, 512])
torch.Size([4, 49, 512])
torch.Size([4, 47, 512])


  0%|          | 294/450338 [00:04<1:41:37, 73.80it/s]

torch.Size([4, 38, 512])
torch.Size([4, 220, 512])
torch.Size([4, 130, 512])
torch.Size([4, 202, 512])
torch.Size([4, 48, 512])
torch.Size([4, 192, 512])
torch.Size([4, 90, 512])
torch.Size([4, 256, 512])
torch.Size([4, 8, 512])
torch.Size([4, 343, 512])
torch.Size([4, 244, 512])
torch.Size([4, 112, 512])
torch.Size([4, 438, 512])


  0%|          | 309/450338 [00:04<1:52:58, 66.39it/s]

torch.Size([4, 65, 512])
torch.Size([4, 188, 512])
torch.Size([4, 174, 512])
torch.Size([4, 254, 512])
torch.Size([4, 134, 512])
torch.Size([4, 115, 512])
torch.Size([4, 161, 512])
torch.Size([4, 215, 512])
torch.Size([4, 264, 512])
torch.Size([4, 44, 512])
torch.Size([4, 143, 512])
torch.Size([4, 201, 512])
torch.Size([4, 135, 512])
torch.Size([4, 249, 512])


  0%|          | 323/450338 [00:05<2:14:49, 55.63it/s]

torch.Size([4, 503, 512])
torch.Size([4, 307, 512])
torch.Size([4, 429, 512])
torch.Size([4, 392, 512])
torch.Size([4, 7, 512])
torch.Size([4, 234, 512])
torch.Size([4, 141, 512])
torch.Size([4, 145, 512])
torch.Size([4, 7, 512])
torch.Size([4, 192, 512])
torch.Size([4, 222, 512])


  0%|          | 330/450338 [00:05<2:09:37, 57.86it/s]

torch.Size([4, 181, 512])
torch.Size([4, 336, 512])
torch.Size([4, 134, 512])
torch.Size([4, 132, 512])
torch.Size([4, 187, 512])
torch.Size([4, 112, 512])
torch.Size([4, 256, 512])
torch.Size([4, 133, 512])
torch.Size([4, 57, 512])
torch.Size([4, 67, 512])


  0%|          | 344/450338 [00:05<2:09:16, 58.02it/s]

torch.Size([4, 651, 512])
torch.Size([4, 184, 512])
torch.Size([4, 107, 512])
torch.Size([4, 46, 512])
torch.Size([4, 114, 512])
torch.Size([4, 121, 512])
torch.Size([4, 9, 512])
torch.Size([4, 9, 512])
torch.Size([4, 10, 512])
torch.Size([4, 9, 512])
torch.Size([4, 198, 512])
torch.Size([4, 146, 512])
torch.Size([4, 104, 512])
torch.Size([4, 186, 512])
torch.Size([4, 173, 512])
torch.Size([4, 130, 512])
torch.Size([4, 83, 512])


  0%|          | 360/450338 [00:05<1:55:24, 64.99it/s]

torch.Size([4, 196, 512])
torch.Size([4, 81, 512])
torch.Size([4, 215, 512])
torch.Size([4, 189, 512])
torch.Size([4, 198, 512])
torch.Size([4, 247, 512])
torch.Size([4, 48, 512])
torch.Size([4, 263, 512])
torch.Size([4, 407, 512])
torch.Size([4, 421, 512])


  0%|          | 367/450338 [00:05<2:20:32, 53.36it/s]

torch.Size([4, 374, 512])
torch.Size([4, 338, 512])
torch.Size([4, 259, 512])
torch.Size([4, 105, 512])
torch.Size([4, 281, 512])
torch.Size([4, 388, 512])


  0%|          | 380/450338 [00:06<2:31:33, 49.48it/s]

torch.Size([4, 198, 512])
torch.Size([4, 101, 512])
torch.Size([4, 214, 512])
torch.Size([4, 73, 512])
torch.Size([4, 161, 512])
torch.Size([4, 108, 512])
torch.Size([4, 190, 512])
torch.Size([4, 222, 512])
torch.Size([4, 243, 512])
torch.Size([4, 245, 512])
torch.Size([4, 231, 512])
torch.Size([4, 294, 512])
torch.Size([4, 190, 512])


  0%|          | 396/450338 [00:06<2:04:24, 60.28it/s]

torch.Size([4, 155, 512])
torch.Size([4, 106, 512])
torch.Size([4, 120, 512])
torch.Size([4, 168, 512])
torch.Size([4, 178, 512])
torch.Size([4, 205, 512])
torch.Size([4, 93, 512])
torch.Size([4, 45, 512])
torch.Size([4, 9, 512])
torch.Size([4, 159, 512])
torch.Size([4, 182, 512])
torch.Size([4, 141, 512])
torch.Size([4, 46, 512])
torch.Size([4, 252, 512])
torch.Size([4, 8, 512])
torch.Size([4, 88, 512])
torch.Size([4, 137, 512])


  0%|          | 406/450338 [00:06<1:48:11, 69.32it/s]

torch.Size([4, 112, 512])
torch.Size([4, 97, 512])
torch.Size([4, 43, 512])
torch.Size([4, 118, 512])
torch.Size([4, 181, 512])
torch.Size([4, 90, 512])
torch.Size([4, 11, 512])
torch.Size([4, 148, 512])
torch.Size([4, 152, 512])
torch.Size([4, 143, 512])
torch.Size([4, 289, 512])
torch.Size([4, 292, 512])
torch.Size([4, 54, 512])
torch.Size([4, 380, 512])


  0%|          | 422/450338 [00:06<1:50:55, 67.60it/s]

torch.Size([4, 128, 512])
torch.Size([4, 159, 512])
torch.Size([4, 124, 512])
torch.Size([4, 124, 512])
torch.Size([4, 172, 512])
torch.Size([4, 110, 512])
torch.Size([4, 83, 512])
torch.Size([4, 248, 512])
torch.Size([4, 218, 512])
torch.Size([4, 301, 512])
torch.Size([4, 9, 512])
torch.Size([4, 220, 512])
torch.Size([4, 216, 512])
torch.Size([4, 210, 512])


  0%|          | 436/450338 [00:06<1:55:11, 65.10it/s]

torch.Size([4, 247, 512])
torch.Size([4, 6, 512])
torch.Size([4, 253, 512])
torch.Size([4, 59, 512])
torch.Size([4, 77, 512])
torch.Size([4, 127, 512])
torch.Size([4, 374, 512])
torch.Size([4, 16, 512])
torch.Size([4, 212, 512])
torch.Size([4, 155, 512])
torch.Size([4, 164, 512])
torch.Size([4, 328, 512])
torch.Size([4, 239, 512])


  0%|          | 450/450338 [00:07<2:10:17, 57.55it/s]

torch.Size([4, 201, 512])
torch.Size([4, 175, 512])
torch.Size([4, 108, 512])
torch.Size([4, 248, 512])
torch.Size([4, 247, 512])
torch.Size([4, 398, 512])
torch.Size([4, 195, 512])
torch.Size([4, 329, 512])
torch.Size([4, 221, 512])
torch.Size([4, 115, 512])
torch.Size([4, 153, 512])


  0%|          | 464/450338 [00:07<2:00:36, 62.16it/s]

torch.Size([4, 141, 512])
torch.Size([4, 138, 512])
torch.Size([4, 277, 512])
torch.Size([4, 256, 512])
torch.Size([4, 10, 512])
torch.Size([4, 10, 512])
torch.Size([4, 236, 512])
torch.Size([4, 185, 512])
torch.Size([4, 235, 512])
torch.Size([4, 173, 512])
torch.Size([4, 108, 512])
torch.Size([4, 162, 512])
torch.Size([4, 157, 512])
torch.Size([4, 9, 512])
torch.Size([4, 100, 512])


  0%|          | 472/450338 [00:07<1:59:29, 62.75it/s]

torch.Size([4, 92, 512])
torch.Size([4, 229, 512])
torch.Size([4, 10, 512])
torch.Size([4, 170, 512])
torch.Size([4, 296, 512])
torch.Size([4, 327, 512])
torch.Size([4, 153, 512])
torch.Size([4, 424, 512])
torch.Size([4, 288, 512])
torch.Size([4, 127, 512])
torch.Size([4, 343, 512])


  0%|          | 485/450338 [00:07<2:25:13, 51.63it/s]

torch.Size([4, 61, 512])
torch.Size([4, 241, 512])
torch.Size([4, 285, 512])
torch.Size([4, 8, 512])
torch.Size([4, 245, 512])
torch.Size([4, 426, 512])
torch.Size([4, 226, 512])
torch.Size([4, 494, 512])
torch.Size([4, 291, 512])


  0%|          | 491/450338 [00:07<2:26:08, 51.31it/s]

torch.Size([4, 353, 512])
torch.Size([4, 156, 512])
torch.Size([4, 262, 512])
torch.Size([4, 148, 512])
torch.Size([4, 185, 512])
torch.Size([4, 72, 512])
torch.Size([4, 142, 512])
torch.Size([4, 109, 512])
torch.Size([4, 191, 512])
torch.Size([4, 191, 512])
torch.Size([4, 162, 512])
torch.Size([4, 8, 512])
torch.Size([4, 168, 512])


  0%|          | 505/450338 [00:08<2:17:27, 54.54it/s]

torch.Size([4, 156, 512])
torch.Size([4, 224, 512])
torch.Size([4, 211, 512])
torch.Size([4, 288, 512])
torch.Size([4, 340, 512])
torch.Size([4, 312, 512])
torch.Size([4, 151, 512])
torch.Size([4, 172, 512])
torch.Size([4, 227, 512])
torch.Size([4, 144, 512])
torch.Size([4, 222, 512])
torch.Size([4, 9, 512])


  0%|          | 519/450338 [00:08<2:04:00, 60.46it/s]

torch.Size([4, 194, 512])
torch.Size([4, 128, 512])
torch.Size([4, 7, 512])
torch.Size([4, 258, 512])
torch.Size([4, 127, 512])
torch.Size([4, 216, 512])
torch.Size([4, 163, 512])
torch.Size([4, 216, 512])
torch.Size([4, 238, 512])
torch.Size([4, 293, 512])
torch.Size([4, 147, 512])
torch.Size([4, 10, 512])
torch.Size([4, 151, 512])


  0%|          | 533/450338 [00:08<2:11:28, 57.02it/s]

torch.Size([4, 174, 512])
torch.Size([4, 165, 512])
torch.Size([4, 155, 512])
torch.Size([4, 254, 512])
torch.Size([4, 211, 512])
torch.Size([4, 12, 512])
torch.Size([4, 346, 512])
torch.Size([4, 351, 512])
torch.Size([4, 236, 512])
torch.Size([4, 211, 512])
torch.Size([4, 14, 512])
torch.Size([4, 145, 512])


  0%|          | 545/450338 [00:08<2:09:04, 58.08it/s]

torch.Size([4, 289, 512])
torch.Size([4, 255, 512])
torch.Size([4, 232, 512])
torch.Size([4, 52, 512])
torch.Size([4, 210, 512])
torch.Size([4, 173, 512])
torch.Size([4, 256, 512])
torch.Size([4, 276, 512])
torch.Size([4, 162, 512])
torch.Size([4, 200, 512])
torch.Size([4, 60, 512])
torch.Size([4, 9, 512])
torch.Size([4, 101, 512])


  0%|          | 564/450338 [00:09<1:39:15, 75.52it/s]

torch.Size([4, 48, 512])
torch.Size([4, 44, 512])
torch.Size([4, 194, 512])
torch.Size([4, 149, 512])
torch.Size([4, 145, 512])
torch.Size([4, 214, 512])
torch.Size([4, 192, 512])
torch.Size([4, 135, 512])
torch.Size([4, 126, 512])
torch.Size([4, 11, 512])
torch.Size([4, 10, 512])
torch.Size([4, 11, 512])
torch.Size([4, 10, 512])
torch.Size([4, 11, 512])
torch.Size([4, 14, 512])
torch.Size([4, 10, 512])
torch.Size([4, 26, 512])
torch.Size([4, 14, 512])
torch.Size([4, 13, 512])
torch.Size([4, 8, 512])


  0%|          | 573/450338 [00:09<1:39:20, 75.46it/s]

torch.Size([4, 235, 512])
torch.Size([4, 241, 512])
torch.Size([4, 241, 512])
torch.Size([4, 249, 512])
torch.Size([4, 261, 512])
torch.Size([4, 160, 512])
torch.Size([4, 341, 512])
torch.Size([4, 358, 512])
torch.Size([4, 328, 512])


  0%|          | 581/450338 [00:09<2:14:53, 55.57it/s]

torch.Size([4, 336, 512])
torch.Size([4, 348, 512])
torch.Size([4, 613, 512])
torch.Size([4, 474, 512])
torch.Size([4, 26, 512])
torch.Size([4, 31, 512])
torch.Size([4, 13, 512])
torch.Size([4, 8, 512])
torch.Size([4, 8, 512])
torch.Size([4, 10, 512])
torch.Size([4, 150, 512])


  0%|          | 596/450338 [00:09<2:05:17, 59.83it/s]

torch.Size([4, 311, 512])
torch.Size([4, 7, 512])
torch.Size([4, 196, 512])
torch.Size([4, 161, 512])
torch.Size([4, 225, 512])
torch.Size([4, 302, 512])
torch.Size([4, 7, 512])
torch.Size([4, 192, 512])
torch.Size([4, 161, 512])
torch.Size([4, 142, 512])
torch.Size([4, 22, 512])
torch.Size([4, 25, 512])
torch.Size([4, 20, 512])
torch.Size([4, 29, 512])
torch.Size([4, 27, 512])


  0%|          | 614/450338 [00:09<1:57:03, 64.03it/s]

torch.Size([4, 7, 512])
torch.Size([4, 214, 512])
torch.Size([4, 196, 512])
torch.Size([4, 295, 512])
torch.Size([4, 209, 512])
torch.Size([4, 239, 512])
torch.Size([4, 52, 512])
torch.Size([4, 317, 512])
torch.Size([4, 117, 512])
torch.Size([4, 236, 512])
torch.Size([4, 10, 512])
torch.Size([4, 267, 512])


  0%|          | 629/450338 [00:10<1:54:11, 65.64it/s]

torch.Size([4, 199, 512])
torch.Size([4, 251, 512])
torch.Size([4, 123, 512])
torch.Size([4, 276, 512])
torch.Size([4, 270, 512])
torch.Size([4, 138, 512])
torch.Size([4, 105, 512])
torch.Size([4, 211, 512])
torch.Size([4, 115, 512])
torch.Size([4, 7, 512])
torch.Size([4, 157, 512])
torch.Size([4, 166, 512])
torch.Size([4, 116, 512])
torch.Size([4, 180, 512])


  0%|          | 648/450338 [00:10<1:37:03, 77.22it/s]

torch.Size([4, 133, 512])
torch.Size([4, 144, 512])
torch.Size([4, 105, 512])
torch.Size([4, 105, 512])
torch.Size([4, 95, 512])
torch.Size([4, 78, 512])
torch.Size([4, 8, 512])
torch.Size([4, 8, 512])
torch.Size([4, 12, 512])
torch.Size([4, 12, 512])
torch.Size([4, 187, 512])
torch.Size([4, 94, 512])
torch.Size([4, 169, 512])
torch.Size([4, 46, 512])
torch.Size([4, 155, 512])
torch.Size([4, 216, 512])
torch.Size([4, 93, 512])
torch.Size([4, 69, 512])
torch.Size([4, 104, 512])


  0%|          | 656/450338 [00:10<1:46:41, 70.25it/s]

torch.Size([4, 207, 512])
torch.Size([4, 132, 512])
torch.Size([4, 17, 512])
torch.Size([4, 81, 512])
torch.Size([4, 208, 512])
torch.Size([4, 181, 512])
torch.Size([4, 223, 512])
torch.Size([4, 109, 512])
torch.Size([4, 72, 512])
torch.Size([4, 135, 512])
torch.Size([4, 125, 512])
torch.Size([4, 185, 512])
torch.Size([4, 123, 512])
torch.Size([4, 179, 512])


  0%|          | 673/450338 [00:10<1:50:07, 68.05it/s]

torch.Size([4, 113, 512])
torch.Size([4, 303, 512])
torch.Size([4, 124, 512])
torch.Size([4, 19, 512])
torch.Size([4, 118, 512])
torch.Size([4, 145, 512])
torch.Size([4, 221, 512])
torch.Size([4, 378, 512])
torch.Size([4, 195, 512])
torch.Size([4, 254, 512])
torch.Size([4, 168, 512])
torch.Size([4, 176, 512])


  0%|          | 686/450338 [00:11<2:13:02, 56.33it/s]

torch.Size([4, 206, 512])
torch.Size([4, 160, 512])
torch.Size([4, 157, 512])
torch.Size([4, 254, 512])
torch.Size([4, 149, 512])
torch.Size([4, 211, 512])
torch.Size([4, 183, 512])
torch.Size([4, 176, 512])
torch.Size([4, 305, 512])
torch.Size([4, 148, 512])
torch.Size([4, 160, 512])
torch.Size([4, 264, 512])


  0%|          | 692/450338 [00:11<2:25:05, 51.65it/s]

torch.Size([4, 299, 512])
torch.Size([4, 248, 512])
torch.Size([4, 260, 512])
torch.Size([4, 344, 512])
torch.Size([4, 266, 512])
torch.Size([4, 282, 512])
torch.Size([4, 268, 512])
torch.Size([4, 245, 512])
torch.Size([4, 183, 512])


  0%|          | 707/450338 [00:11<2:05:42, 59.61it/s]

torch.Size([4, 226, 512])
torch.Size([4, 236, 512])
torch.Size([4, 281, 512])
torch.Size([4, 331, 512])
torch.Size([4, 10, 512])
torch.Size([4, 9, 512])
torch.Size([4, 9, 512])
torch.Size([4, 12, 512])
torch.Size([4, 8, 512])
torch.Size([4, 8, 512])
torch.Size([4, 131, 512])
torch.Size([4, 103, 512])
torch.Size([4, 217, 512])
torch.Size([4, 9, 512])
torch.Size([4, 282, 512])


  0%|          | 721/450338 [00:11<2:12:18, 56.64it/s]

torch.Size([4, 223, 512])
torch.Size([4, 231, 512])
torch.Size([4, 231, 512])
torch.Size([4, 306, 512])
torch.Size([4, 275, 512])
torch.Size([4, 155, 512])
torch.Size([4, 244, 512])
torch.Size([4, 183, 512])
torch.Size([4, 266, 512])
torch.Size([4, 218, 512])
torch.Size([4, 154, 512])


  0%|          | 742/450338 [00:11<1:38:10, 76.32it/s]

torch.Size([4, 18, 512])
torch.Size([4, 13, 512])
torch.Size([4, 10, 512])
torch.Size([4, 11, 512])
torch.Size([4, 17, 512])
torch.Size([4, 8, 512])
torch.Size([4, 23, 512])
torch.Size([4, 19, 512])
torch.Size([4, 10, 512])
torch.Size([4, 11, 512])
torch.Size([4, 26, 512])
torch.Size([4, 25, 512])
torch.Size([4, 153, 512])
torch.Size([4, 97, 512])
torch.Size([4, 90, 512])
torch.Size([4, 253, 512])
torch.Size([4, 144, 512])
torch.Size([4, 118, 512])
torch.Size([4, 35, 512])
torch.Size([4, 13, 512])
torch.Size([4, 25, 512])


  0%|          | 754/450338 [00:11<1:25:17, 87.86it/s]

torch.Size([4, 57, 512])
torch.Size([4, 17, 512])
torch.Size([4, 28, 512])
torch.Size([4, 19, 512])
torch.Size([4, 66, 512])
torch.Size([4, 10, 512])
torch.Size([4, 14, 512])
torch.Size([4, 15, 512])
torch.Size([4, 11, 512])
torch.Size([4, 11, 512])
torch.Size([4, 21, 512])
torch.Size([4, 10, 512])
torch.Size([4, 11, 512])
torch.Size([4, 26, 512])
torch.Size([4, 46, 512])
torch.Size([4, 21, 512])
torch.Size([4, 49, 512])
torch.Size([4, 147, 512])
torch.Size([4, 136, 512])
torch.Size([4, 160, 512])
torch.Size([4, 68, 512])
torch.Size([4, 72, 512])
torch.Size([4, 174, 512])


  0%|          | 775/450338 [00:12<1:21:29, 91.95it/s]

torch.Size([4, 7, 512])
torch.Size([4, 77, 512])
torch.Size([4, 160, 512])
torch.Size([4, 55, 512])
torch.Size([4, 35, 512])
torch.Size([4, 105, 512])
torch.Size([4, 67, 512])
torch.Size([4, 34, 512])
torch.Size([4, 31, 512])
torch.Size([4, 168, 512])
torch.Size([4, 137, 512])
torch.Size([4, 104, 512])
torch.Size([4, 142, 512])
torch.Size([4, 174, 512])
torch.Size([4, 132, 512])
torch.Size([4, 187, 512])
torch.Size([4, 42, 512])
torch.Size([4, 257, 512])


  0%|          | 794/450338 [00:12<1:39:45, 75.10it/s]

torch.Size([4, 261, 512])
torch.Size([4, 14, 512])
torch.Size([4, 222, 512])
torch.Size([4, 178, 512])
torch.Size([4, 168, 512])
torch.Size([4, 287, 512])
torch.Size([4, 7, 512])
torch.Size([4, 242, 512])
torch.Size([4, 236, 512])
torch.Size([4, 211, 512])
torch.Size([4, 327, 512])
torch.Size([4, 300, 512])


  0%|          | 802/450338 [00:12<1:55:35, 64.81it/s]

torch.Size([4, 288, 512])
torch.Size([4, 294, 512])
torch.Size([4, 222, 512])
torch.Size([4, 242, 512])
torch.Size([4, 235, 512])
torch.Size([4, 192, 512])
torch.Size([4, 16, 512])
torch.Size([4, 294, 512])
torch.Size([4, 203, 512])
torch.Size([4, 235, 512])
torch.Size([4, 363, 512])


  0%|          | 816/450338 [00:12<2:05:32, 59.68it/s]

torch.Size([4, 246, 512])
torch.Size([4, 183, 512])
torch.Size([4, 176, 512])
torch.Size([4, 198, 512])
torch.Size([4, 245, 512])
torch.Size([4, 153, 512])
torch.Size([4, 296, 512])
torch.Size([4, 185, 512])
torch.Size([4, 195, 512])
torch.Size([4, 96, 512])
torch.Size([4, 173, 512])
torch.Size([4, 221, 512])


  0%|          | 829/450338 [00:13<2:08:17, 58.40it/s]

torch.Size([4, 263, 512])
torch.Size([4, 225, 512])
torch.Size([4, 300, 512])
torch.Size([4, 203, 512])
torch.Size([4, 190, 512])
torch.Size([4, 68, 512])
torch.Size([4, 242, 512])
torch.Size([4, 268, 512])
torch.Size([4, 185, 512])
torch.Size([4, 210, 512])
torch.Size([4, 257, 512])


  0%|          | 835/450338 [00:13<2:18:24, 54.13it/s]

torch.Size([4, 366, 512])
torch.Size([4, 227, 512])
torch.Size([4, 353, 512])
torch.Size([4, 286, 512])
torch.Size([4, 56, 512])
torch.Size([4, 290, 512])
torch.Size([4, 113, 512])
torch.Size([4, 76, 512])
torch.Size([4, 124, 512])
torch.Size([4, 189, 512])
torch.Size([4, 7, 512])
torch.Size([4, 158, 512])


  0%|          | 851/450338 [00:13<1:59:27, 62.71it/s]

torch.Size([4, 10, 512])
torch.Size([4, 178, 512])
torch.Size([4, 171, 512])
torch.Size([4, 197, 512])
torch.Size([4, 18, 512])
torch.Size([4, 178, 512])
torch.Size([4, 246, 512])
torch.Size([4, 12, 512])
torch.Size([4, 148, 512])
torch.Size([4, 144, 512])
torch.Size([4, 89, 512])
torch.Size([4, 41, 512])
torch.Size([4, 109, 512])
torch.Size([4, 78, 512])
torch.Size([4, 252, 512])
torch.Size([4, 137, 512])


  0%|          | 866/450338 [00:13<1:59:24, 62.74it/s]

torch.Size([4, 257, 512])
torch.Size([4, 92, 512])
torch.Size([4, 9, 512])
torch.Size([4, 389, 512])
torch.Size([4, 178, 512])
torch.Size([4, 162, 512])
torch.Size([4, 8, 512])
torch.Size([4, 299, 512])
torch.Size([4, 89, 512])
torch.Size([4, 269, 512])
torch.Size([4, 138, 512])
torch.Size([4, 46, 512])


  0%|          | 874/450338 [00:13<1:52:05, 66.83it/s]

torch.Size([4, 208, 512])
torch.Size([4, 86, 512])
torch.Size([4, 196, 512])
torch.Size([4, 8, 512])
torch.Size([4, 193, 512])
torch.Size([4, 194, 512])
torch.Size([4, 191, 512])
torch.Size([4, 443, 512])
torch.Size([4, 324, 512])
torch.Size([4, 99, 512])
torch.Size([4, 348, 512])


  0%|          | 889/450338 [00:14<1:58:15, 63.34it/s]

torch.Size([4, 100, 512])
torch.Size([4, 185, 512])
torch.Size([4, 221, 512])
torch.Size([4, 49, 512])
torch.Size([4, 97, 512])
torch.Size([4, 39, 512])
torch.Size([4, 114, 512])
torch.Size([4, 299, 512])
torch.Size([4, 124, 512])
torch.Size([4, 73, 512])
torch.Size([4, 8, 512])
torch.Size([4, 62, 512])
torch.Size([4, 138, 512])
torch.Size([4, 200, 512])
torch.Size([4, 149, 512])
torch.Size([4, 8, 512])
torch.Size([4, 327, 512])


  0%|          | 905/450338 [00:14<2:13:24, 56.15it/s]

torch.Size([4, 389, 512])
torch.Size([4, 260, 512])
torch.Size([4, 11, 512])
torch.Size([4, 432, 512])
torch.Size([4, 207, 512])
torch.Size([4, 442, 512])
torch.Size([4, 219, 512])
torch.Size([4, 259, 512])


  0%|          | 922/450338 [00:14<1:51:29, 67.19it/s]

torch.Size([4, 8, 512])
torch.Size([4, 12, 512])
torch.Size([4, 11, 512])
torch.Size([4, 9, 512])
torch.Size([4, 9, 512])
torch.Size([4, 90, 512])
torch.Size([4, 81, 512])
torch.Size([4, 143, 512])
torch.Size([4, 234, 512])
torch.Size([4, 113, 512])
torch.Size([4, 137, 512])
torch.Size([4, 63, 512])
torch.Size([4, 168, 512])
torch.Size([4, 55, 512])
torch.Size([4, 180, 512])
torch.Size([4, 153, 512])
torch.Size([4, 101, 512])
torch.Size([4, 135, 512])


  0%|          | 930/450338 [00:14<1:46:31, 70.31it/s]

torch.Size([4, 193, 512])
torch.Size([4, 228, 512])
torch.Size([4, 69, 512])
torch.Size([4, 70, 512])
torch.Size([4, 38, 512])
torch.Size([4, 148, 512])
torch.Size([4, 105, 512])
torch.Size([4, 178, 512])
torch.Size([4, 140, 512])
torch.Size([4, 281, 512])
torch.Size([4, 259, 512])
torch.Size([4, 204, 512])
torch.Size([4, 184, 512])
torch.Size([4, 10, 512])


  0%|          | 938/450338 [00:14<1:49:48, 68.21it/s]

torch.Size([4, 247, 512])
torch.Size([4, 192, 512])
torch.Size([4, 425, 512])
torch.Size([4, 241, 512])
torch.Size([4, 448, 512])
torch.Size([4, 225, 512])
torch.Size([4, 258, 512])
torch.Size([4, 287, 512])


  0%|          | 953/450338 [00:15<2:13:45, 55.99it/s]

torch.Size([4, 334, 512])
torch.Size([4, 292, 512])
torch.Size([4, 9, 512])
torch.Size([4, 160, 512])
torch.Size([4, 169, 512])
torch.Size([4, 173, 512])
torch.Size([4, 191, 512])
torch.Size([4, 10, 512])
torch.Size([4, 288, 512])
torch.Size([4, 311, 512])
torch.Size([4, 125, 512])
torch.Size([4, 73, 512])


  0%|          | 970/450338 [00:15<1:52:51, 66.36it/s]

torch.Size([4, 166, 512])
torch.Size([4, 130, 512])
torch.Size([4, 23, 512])
torch.Size([4, 10, 512])
torch.Size([4, 74, 512])
torch.Size([4, 10, 512])
torch.Size([4, 96, 512])
torch.Size([4, 155, 512])
torch.Size([4, 157, 512])
torch.Size([4, 166, 512])
torch.Size([4, 180, 512])
torch.Size([4, 121, 512])
torch.Size([4, 198, 512])
torch.Size([4, 158, 512])
torch.Size([4, 23, 512])
torch.Size([4, 131, 512])
torch.Size([4, 131, 512])


  0%|          | 987/450338 [00:15<1:43:47, 72.15it/s]

torch.Size([4, 79, 512])
torch.Size([4, 107, 512])
torch.Size([4, 74, 512])
torch.Size([4, 53, 512])
torch.Size([4, 50, 512])
torch.Size([4, 220, 512])
torch.Size([4, 76, 512])
torch.Size([4, 12, 512])
torch.Size([4, 119, 512])
torch.Size([4, 190, 512])
torch.Size([4, 177, 512])
torch.Size([4, 252, 512])
torch.Size([4, 164, 512])
torch.Size([4, 197, 512])
torch.Size([4, 394, 512])


  0%|          | 995/450338 [00:15<2:01:22, 61.70it/s]

torch.Size([4, 268, 512])
torch.Size([4, 220, 512])
torch.Size([4, 344, 512])
torch.Size([4, 164, 512])
torch.Size([4, 310, 512])
torch.Size([4, 292, 512])
torch.Size([4, 183, 512])
torch.Size([4, 291, 512])
torch.Size([4, 146, 512])
torch.Size([4, 272, 512])


  0%|          | 1012/450338 [00:16<1:49:43, 68.25it/s]

torch.Size([4, 217, 512])
torch.Size([4, 60, 512])
torch.Size([4, 218, 512])
torch.Size([4, 122, 512])
torch.Size([4, 78, 512])
torch.Size([4, 117, 512])
torch.Size([4, 131, 512])
torch.Size([4, 100, 512])
torch.Size([4, 30, 512])
torch.Size([4, 53, 512])
torch.Size([4, 80, 512])
torch.Size([4, 71, 512])
torch.Size([4, 113, 512])
torch.Size([4, 70, 512])
torch.Size([4, 115, 512])
torch.Size([4, 201, 512])


  0%|          | 1029/450338 [00:16<1:42:34, 73.00it/s]

torch.Size([4, 75, 512])
torch.Size([4, 110, 512])
torch.Size([4, 41, 512])
torch.Size([4, 123, 512])
torch.Size([4, 121, 512])
torch.Size([4, 130, 512])
torch.Size([4, 184, 512])
torch.Size([4, 71, 512])
torch.Size([4, 216, 512])
torch.Size([4, 102, 512])
torch.Size([4, 27, 512])
torch.Size([4, 100, 512])
torch.Size([4, 43, 512])
torch.Size([4, 74, 512])
torch.Size([4, 56, 512])
torch.Size([4, 144, 512])
torch.Size([4, 73, 512])
torch.Size([4, 111, 512])
torch.Size([4, 82, 512])


  0%|          | 1039/450338 [00:16<1:33:47, 79.83it/s]

torch.Size([4, 33, 512])
torch.Size([4, 193, 512])
torch.Size([4, 70, 512])
torch.Size([4, 8, 512])
torch.Size([4, 151, 512])
torch.Size([4, 272, 512])
torch.Size([4, 254, 512])
torch.Size([4, 231, 512])
torch.Size([4, 184, 512])
torch.Size([4, 325, 512])
torch.Size([4, 204, 512])
torch.Size([4, 260, 512])


  0%|          | 1056/450338 [00:16<1:44:58, 71.34it/s]

torch.Size([4, 16, 512])
torch.Size([4, 24, 512])
torch.Size([4, 17, 512])
torch.Size([4, 8, 512])
torch.Size([4, 113, 512])
torch.Size([4, 336, 512])
torch.Size([4, 89, 512])
torch.Size([4, 216, 512])
torch.Size([4, 118, 512])
torch.Size([4, 300, 512])
torch.Size([4, 184, 512])
torch.Size([4, 118, 512])
torch.Size([4, 112, 512])
torch.Size([4, 15, 512])
torch.Size([4, 154, 512])
torch.Size([4, 159, 512])


  0%|          | 1072/450338 [00:16<1:43:40, 72.22it/s]

torch.Size([4, 222, 512])
torch.Size([4, 202, 512])
torch.Size([4, 171, 512])
torch.Size([4, 146, 512])
torch.Size([4, 77, 512])
torch.Size([4, 83, 512])
torch.Size([4, 151, 512])
torch.Size([4, 123, 512])
torch.Size([4, 203, 512])
torch.Size([4, 80, 512])
torch.Size([4, 345, 512])
torch.Size([4, 231, 512])
torch.Size([4, 123, 512])
torch.Size([4, 206, 512])


  0%|          | 1089/450338 [00:17<1:46:57, 70.00it/s]

torch.Size([4, 160, 512])
torch.Size([4, 105, 512])
torch.Size([4, 93, 512])
torch.Size([4, 215, 512])
torch.Size([4, 10, 512])
torch.Size([4, 101, 512])
torch.Size([4, 144, 512])
torch.Size([4, 13, 512])
torch.Size([4, 63, 512])
torch.Size([4, 77, 512])
torch.Size([4, 120, 512])
torch.Size([4, 504, 512])
torch.Size([4, 91, 512])
torch.Size([4, 198, 512])
torch.Size([4, 24, 512])
torch.Size([4, 13, 512])


  0%|          | 1107/450338 [00:17<1:36:38, 77.47it/s]

torch.Size([4, 11, 512])
torch.Size([4, 241, 512])
torch.Size([4, 83, 512])
torch.Size([4, 10, 512])
torch.Size([4, 275, 512])
torch.Size([4, 73, 512])
torch.Size([4, 248, 512])
torch.Size([4, 109, 512])
torch.Size([4, 153, 512])
torch.Size([4, 12, 512])
torch.Size([4, 125, 512])
torch.Size([4, 81, 512])
torch.Size([4, 85, 512])
torch.Size([4, 95, 512])
torch.Size([4, 85, 512])
torch.Size([4, 17, 512])
torch.Size([4, 11, 512])
torch.Size([4, 7, 512])


  0%|          | 1115/450338 [00:17<1:36:29, 77.59it/s]

torch.Size([4, 171, 512])
torch.Size([4, 130, 512])
torch.Size([4, 234, 512])
torch.Size([4, 253, 512])
torch.Size([4, 273, 512])
torch.Size([4, 254, 512])
torch.Size([4, 193, 512])
torch.Size([4, 204, 512])
torch.Size([4, 270, 512])
torch.Size([4, 247, 512])
torch.Size([4, 61, 512])
torch.Size([4, 194, 512])


  0%|          | 1131/450338 [00:17<1:45:18, 71.09it/s]

torch.Size([4, 158, 512])
torch.Size([4, 96, 512])
torch.Size([4, 170, 512])
torch.Size([4, 81, 512])
torch.Size([4, 14, 512])
torch.Size([4, 318, 512])
torch.Size([4, 146, 512])
torch.Size([4, 91, 512])
torch.Size([4, 240, 512])
torch.Size([4, 155, 512])
torch.Size([4, 79, 512])
torch.Size([4, 77, 512])
torch.Size([4, 114, 512])
torch.Size([4, 79, 512])
torch.Size([4, 88, 512])
torch.Size([4, 212, 512])


  0%|          | 1149/450338 [00:17<1:39:53, 74.95it/s]

torch.Size([4, 304, 512])
torch.Size([4, 89, 512])
torch.Size([4, 118, 512])
torch.Size([4, 191, 512])
torch.Size([4, 101, 512])
torch.Size([4, 118, 512])
torch.Size([4, 125, 512])
torch.Size([4, 241, 512])
torch.Size([4, 123, 512])
torch.Size([4, 94, 512])
torch.Size([4, 14, 512])
torch.Size([4, 184, 512])
torch.Size([4, 228, 512])
torch.Size([4, 170, 512])
torch.Size([4, 38, 512])
torch.Size([4, 178, 512])


  0%|          | 1165/450338 [00:18<1:58:55, 62.95it/s]

torch.Size([4, 95, 512])
torch.Size([4, 292, 512])
torch.Size([4, 178, 512])
torch.Size([4, 171, 512])
torch.Size([4, 180, 512])
torch.Size([4, 10, 512])
torch.Size([4, 150, 512])
torch.Size([4, 245, 512])
torch.Size([4, 86, 512])
torch.Size([4, 154, 512])
torch.Size([4, 178, 512])
torch.Size([4, 138, 512])
torch.Size([4, 331, 512])
torch.Size([4, 148, 512])


  0%|          | 1180/450338 [00:18<1:52:38, 66.46it/s]

torch.Size([4, 206, 512])
torch.Size([4, 138, 512])
torch.Size([4, 26, 512])
torch.Size([4, 191, 512])
torch.Size([4, 80, 512])
torch.Size([4, 260, 512])
torch.Size([4, 220, 512])
torch.Size([4, 51, 512])
torch.Size([4, 122, 512])
torch.Size([4, 57, 512])
torch.Size([4, 94, 512])
torch.Size([4, 48, 512])
torch.Size([4, 84, 512])
torch.Size([4, 60, 512])
torch.Size([4, 56, 512])
torch.Size([4, 17, 512])
torch.Size([4, 65, 512])
torch.Size([4, 83, 512])


  0%|          | 1201/450338 [00:18<1:34:33, 79.17it/s]

torch.Size([4, 29, 512])
torch.Size([4, 43, 512])
torch.Size([4, 69, 512])
torch.Size([4, 27, 512])
torch.Size([4, 26, 512])
torch.Size([4, 27, 512])
torch.Size([4, 44, 512])
torch.Size([4, 40, 512])
torch.Size([4, 36, 512])
torch.Size([4, 55, 512])
torch.Size([4, 21, 512])
torch.Size([4, 236, 512])
torch.Size([4, 166, 512])
torch.Size([4, 330, 512])
torch.Size([4, 72, 512])
torch.Size([4, 169, 512])
torch.Size([4, 127, 512])
torch.Size([4, 74, 512])


  0%|          | 1210/450338 [00:18<1:34:26, 79.26it/s]

torch.Size([4, 8, 512])
torch.Size([4, 87, 512])
torch.Size([4, 222, 512])
torch.Size([4, 228, 512])
torch.Size([4, 146, 512])
torch.Size([4, 176, 512])
torch.Size([4, 156, 512])
torch.Size([4, 174, 512])
torch.Size([4, 272, 512])
torch.Size([4, 108, 512])
torch.Size([4, 201, 512])
torch.Size([4, 302, 512])
torch.Size([4, 200, 512])


  0%|          | 1229/450338 [00:19<1:39:29, 75.23it/s]

torch.Size([4, 65, 512])
torch.Size([4, 195, 512])
torch.Size([4, 102, 512])
torch.Size([4, 121, 512])
torch.Size([4, 44, 512])
torch.Size([4, 28, 512])
torch.Size([4, 20, 512])
torch.Size([4, 164, 512])
torch.Size([4, 39, 512])
torch.Size([4, 23, 512])
torch.Size([4, 275, 512])
torch.Size([4, 169, 512])
torch.Size([4, 74, 512])
torch.Size([4, 209, 512])
torch.Size([4, 142, 512])
torch.Size([4, 270, 512])


  0%|          | 1247/450338 [00:19<1:35:53, 78.05it/s]

torch.Size([4, 128, 512])
torch.Size([4, 227, 512])
torch.Size([4, 27, 512])
torch.Size([4, 39, 512])
torch.Size([4, 132, 512])
torch.Size([4, 122, 512])
torch.Size([4, 50, 512])
torch.Size([4, 201, 512])
torch.Size([4, 105, 512])
torch.Size([4, 63, 512])
torch.Size([4, 144, 512])
torch.Size([4, 88, 512])
torch.Size([4, 135, 512])
torch.Size([4, 242, 512])
torch.Size([4, 108, 512])
torch.Size([4, 200, 512])
torch.Size([4, 19, 512])


  0%|          | 1265/450338 [00:19<1:30:46, 82.45it/s]

torch.Size([4, 140, 512])
torch.Size([4, 93, 512])
torch.Size([4, 146, 512])
torch.Size([4, 153, 512])
torch.Size([4, 70, 512])
torch.Size([4, 113, 512])
torch.Size([4, 65, 512])
torch.Size([4, 38, 512])
torch.Size([4, 81, 512])
torch.Size([4, 69, 512])
torch.Size([4, 162, 512])
torch.Size([4, 81, 512])
torch.Size([4, 107, 512])
torch.Size([4, 181, 512])
torch.Size([4, 92, 512])
torch.Size([4, 110, 512])
torch.Size([4, 166, 512])
torch.Size([4, 62, 512])
torch.Size([4, 135, 512])


  0%|          | 1283/450338 [00:19<1:29:21, 83.76it/s]

torch.Size([4, 107, 512])
torch.Size([4, 113, 512])
torch.Size([4, 148, 512])
torch.Size([4, 108, 512])
torch.Size([4, 223, 512])
torch.Size([4, 110, 512])
torch.Size([4, 91, 512])
torch.Size([4, 131, 512])
torch.Size([4, 178, 512])
torch.Size([4, 129, 512])
torch.Size([4, 10, 512])
torch.Size([4, 7, 512])
torch.Size([4, 178, 512])
torch.Size([4, 233, 512])
torch.Size([4, 361, 512])


  0%|          | 1292/450338 [00:19<1:43:37, 72.23it/s]

torch.Size([4, 360, 512])
torch.Size([4, 184, 512])
torch.Size([4, 276, 512])
torch.Size([4, 123, 512])
torch.Size([4, 137, 512])
torch.Size([4, 152, 512])
torch.Size([4, 103, 512])
torch.Size([4, 112, 512])
torch.Size([4, 195, 512])
torch.Size([4, 228, 512])
torch.Size([4, 115, 512])
torch.Size([4, 229, 512])
torch.Size([4, 190, 512])


  0%|          | 1308/450338 [00:20<1:43:24, 72.38it/s]

torch.Size([4, 256, 512])
torch.Size([4, 18, 512])
torch.Size([4, 9, 512])
torch.Size([4, 141, 512])
torch.Size([4, 110, 512])
torch.Size([4, 152, 512])
torch.Size([4, 220, 512])
torch.Size([4, 174, 512])
torch.Size([4, 123, 512])
torch.Size([4, 11, 512])
torch.Size([4, 219, 512])
torch.Size([4, 191, 512])
torch.Size([4, 248, 512])
torch.Size([4, 143, 512])
torch.Size([4, 164, 512])


  0%|          | 1324/450338 [00:20<1:45:05, 71.21it/s]

torch.Size([4, 205, 512])
torch.Size([4, 232, 512])
torch.Size([4, 185, 512])
torch.Size([4, 310, 512])
torch.Size([4, 79, 512])
torch.Size([4, 27, 512])
torch.Size([4, 100, 512])
torch.Size([4, 103, 512])
torch.Size([4, 175, 512])
torch.Size([4, 70, 512])
torch.Size([4, 150, 512])
torch.Size([4, 106, 512])
torch.Size([4, 127, 512])
torch.Size([4, 37, 512])
torch.Size([4, 117, 512])
torch.Size([4, 104, 512])


  0%|          | 1342/450338 [00:20<1:48:15, 69.12it/s]

torch.Size([4, 144, 512])
torch.Size([4, 104, 512])
torch.Size([4, 72, 512])
torch.Size([4, 8, 512])
torch.Size([4, 107, 512])
torch.Size([4, 265, 512])
torch.Size([4, 299, 512])
torch.Size([4, 223, 512])
torch.Size([4, 220, 512])
torch.Size([4, 263, 512])
torch.Size([4, 331, 512])
torch.Size([4, 10, 512])
torch.Size([4, 16, 512])
torch.Size([4, 10, 512])


  0%|          | 1352/450338 [00:20<1:40:23, 74.54it/s]

torch.Size([4, 14, 512])
torch.Size([4, 13, 512])
torch.Size([4, 11, 512])
torch.Size([4, 10, 512])
torch.Size([4, 10, 512])
torch.Size([4, 9, 512])
torch.Size([4, 185, 512])
torch.Size([4, 197, 512])
torch.Size([4, 203, 512])
torch.Size([4, 192, 512])
torch.Size([4, 238, 512])
torch.Size([4, 8, 512])
torch.Size([4, 8, 512])
torch.Size([4, 10, 512])
torch.Size([4, 11, 512])
torch.Size([4, 11, 512])
torch.Size([4, 9, 512])
torch.Size([4, 8, 512])
torch.Size([4, 252, 512])


  0%|          | 1370/450338 [00:20<1:44:21, 71.70it/s]

torch.Size([4, 187, 512])
torch.Size([4, 16, 512])
torch.Size([4, 216, 512])
torch.Size([4, 185, 512])
torch.Size([4, 204, 512])
torch.Size([4, 331, 512])
torch.Size([4, 201, 512])
torch.Size([4, 167, 512])
torch.Size([4, 279, 512])
torch.Size([4, 245, 512])
torch.Size([4, 291, 512])
torch.Size([4, 84, 512])


  0%|          | 1378/450338 [00:21<1:49:46, 68.16it/s]

torch.Size([4, 296, 512])
torch.Size([4, 71, 512])
torch.Size([4, 145, 512])
torch.Size([4, 9, 512])
torch.Size([4, 228, 512])
torch.Size([4, 213, 512])
torch.Size([4, 291, 512])
torch.Size([4, 255, 512])
torch.Size([4, 223, 512])
torch.Size([4, 302, 512])
torch.Size([4, 320, 512])


  0%|          | 1393/450338 [00:21<2:01:44, 61.46it/s]

torch.Size([4, 7, 512])
torch.Size([4, 146, 512])
torch.Size([4, 45, 512])
torch.Size([4, 37, 512])
torch.Size([4, 79, 512])
torch.Size([4, 147, 512])
torch.Size([4, 83, 512])
torch.Size([4, 404, 512])
torch.Size([4, 255, 512])
torch.Size([4, 330, 512])
torch.Size([4, 212, 512])
torch.Size([4, 324, 512])


  0%|          | 1410/450338 [00:21<1:49:40, 68.22it/s]

torch.Size([4, 151, 512])
torch.Size([4, 202, 512])
torch.Size([4, 10, 512])
torch.Size([4, 61, 512])
torch.Size([4, 124, 512])
torch.Size([4, 185, 512])
torch.Size([4, 97, 512])
torch.Size([4, 64, 512])
torch.Size([4, 108, 512])
torch.Size([4, 78, 512])
torch.Size([4, 68, 512])
torch.Size([4, 115, 512])
torch.Size([4, 115, 512])
torch.Size([4, 143, 512])
torch.Size([4, 155, 512])
torch.Size([4, 238, 512])
torch.Size([4, 17, 512])
torch.Size([4, 266, 512])


  0%|          | 1425/450338 [00:21<1:52:41, 66.39it/s]

torch.Size([4, 123, 512])
torch.Size([4, 58, 512])
torch.Size([4, 387, 512])
torch.Size([4, 8, 512])
torch.Size([4, 16, 512])
torch.Size([4, 35, 512])
torch.Size([4, 165, 512])
torch.Size([4, 215, 512])
torch.Size([4, 188, 512])
torch.Size([4, 352, 512])
torch.Size([4, 87, 512])
torch.Size([4, 96, 512])
torch.Size([4, 10, 512])
torch.Size([4, 8, 512])
torch.Size([4, 194, 512])


  0%|          | 1441/450338 [00:22<1:53:42, 65.80it/s]

torch.Size([4, 128, 512])
torch.Size([4, 266, 512])
torch.Size([4, 198, 512])
torch.Size([4, 236, 512])
torch.Size([4, 258, 512])
torch.Size([4, 291, 512])
torch.Size([4, 82, 512])
torch.Size([4, 172, 512])
torch.Size([4, 48, 512])
torch.Size([4, 226, 512])
torch.Size([4, 81, 512])
torch.Size([4, 63, 512])
torch.Size([4, 41, 512])


  0%|          | 1460/450338 [00:22<1:38:10, 76.20it/s]

torch.Size([4, 44, 512])
torch.Size([4, 47, 512])
torch.Size([4, 46, 512])
torch.Size([4, 62, 512])
torch.Size([4, 49, 512])
torch.Size([4, 41, 512])
torch.Size([4, 78, 512])
torch.Size([4, 52, 512])
torch.Size([4, 37, 512])
torch.Size([4, 44, 512])
torch.Size([4, 50, 512])
torch.Size([4, 48, 512])
torch.Size([4, 48, 512])
torch.Size([4, 266, 512])
torch.Size([4, 59, 512])
torch.Size([4, 139, 512])
torch.Size([4, 284, 512])
torch.Size([4, 24, 512])
torch.Size([4, 35, 512])


  0%|          | 1475/450338 [00:22<1:55:01, 65.03it/s]

torch.Size([4, 35, 512])
torch.Size([4, 35, 512])
torch.Size([4, 35, 512])
torch.Size([4, 35, 512])
torch.Size([4, 35, 512])
torch.Size([4, 102, 512])
torch.Size([4, 136, 512])
torch.Size([4, 54, 512])
torch.Size([4, 393, 512])
torch.Size([4, 113, 512])
torch.Size([4, 90, 512])
torch.Size([4, 101, 512])
torch.Size([4, 189, 512])
torch.Size([4, 239, 512])
torch.Size([4, 10, 512])
torch.Size([4, 270, 512])


  0%|          | 1482/450338 [00:22<1:59:55, 62.38it/s]

torch.Size([4, 368, 512])
torch.Size([4, 112, 512])
torch.Size([4, 59, 512])
torch.Size([4, 248, 512])
torch.Size([4, 266, 512])
torch.Size([4, 123, 512])
torch.Size([4, 205, 512])
torch.Size([4, 7, 512])
torch.Size([4, 93, 512])
torch.Size([4, 10, 512])
torch.Size([4, 178, 512])
torch.Size([4, 64, 512])


  0%|          | 1504/450338 [00:22<1:54:43, 65.21it/s]

torch.Size([4, 313, 512])
torch.Size([4, 103, 512])
torch.Size([4, 114, 512])
torch.Size([4, 9, 512])
torch.Size([4, 285, 512])
torch.Size([4, 189, 512])
torch.Size([4, 78, 512])
torch.Size([4, 152, 512])
torch.Size([4, 11, 512])
torch.Size([4, 214, 512])
torch.Size([4, 165, 512])
torch.Size([4, 161, 512])
torch.Size([4, 10, 512])
torch.Size([4, 147, 512])
torch.Size([4, 148, 512])


  0%|          | 1512/450338 [00:23<1:50:30, 67.69it/s]

torch.Size([4, 12, 512])
torch.Size([4, 203, 512])
torch.Size([4, 263, 512])
torch.Size([4, 178, 512])
torch.Size([4, 86, 512])
torch.Size([4, 108, 512])
torch.Size([4, 105, 512])
torch.Size([4, 178, 512])
torch.Size([4, 125, 512])
torch.Size([4, 84, 512])
torch.Size([4, 211, 512])
torch.Size([4, 196, 512])
torch.Size([4, 377, 512])
torch.Size([4, 7, 512])


  0%|          | 1526/450338 [00:23<2:03:45, 60.44it/s]

torch.Size([4, 410, 512])
torch.Size([4, 183, 512])
torch.Size([4, 222, 512])
torch.Size([4, 287, 512])
torch.Size([4, 289, 512])
torch.Size([4, 103, 512])
torch.Size([4, 133, 512])
torch.Size([4, 75, 512])
torch.Size([4, 162, 512])
torch.Size([4, 65, 512])
torch.Size([4, 143, 512])
torch.Size([4, 177, 512])
torch.Size([4, 117, 512])


  0%|          | 1541/450338 [00:23<2:02:45, 60.93it/s]

torch.Size([4, 152, 512])
torch.Size([4, 182, 512])
torch.Size([4, 273, 512])
torch.Size([4, 125, 512])
torch.Size([4, 111, 512])
torch.Size([4, 126, 512])
torch.Size([4, 211, 512])
torch.Size([4, 195, 512])
torch.Size([4, 406, 512])
torch.Size([4, 178, 512])
torch.Size([4, 323, 512])


  0%|          | 1548/450338 [00:23<1:59:51, 62.40it/s]

torch.Size([4, 94, 512])
torch.Size([4, 214, 512])
torch.Size([4, 61, 512])
torch.Size([4, 191, 512])
torch.Size([4, 74, 512])
torch.Size([4, 148, 512])
torch.Size([4, 214, 512])
torch.Size([4, 268, 512])
torch.Size([4, 14, 512])
torch.Size([4, 235, 512])
torch.Size([4, 219, 512])


  0%|          | 1558/450338 [00:23<1:55:11, 64.94it/s]

torch.Size([4, 487, 512])
torch.Size([4, 276, 512])
torch.Size([4, 267, 512])
torch.Size([4, 273, 512])
torch.Size([4, 228, 512])





OutOfMemoryError: CUDA out of memory. Tried to allocate 648.00 MiB (GPU 0; 5.67 GiB total capacity; 4.90 GiB already allocated; 187.69 MiB free; 5.33 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [6]:
torch.cuda.empty_cache()