<font size="5">Import Libraries</font>

In [12]:
import torch
from torchvision import transforms as T
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os
from tqdm import tqdm
from dalle_pytorch import DALLE, OpenAIDiscreteVAE, DiscreteVAE
from dalle_pytorch.tokenizer import SimpleTokenizer
from torchvision.datasets.coco import CocoCaptions

<font size="5">Setting Dataset & Training Parameters</font>

In [13]:
# Change your input size here
input_image_size = 256

# Change your batch size here
batch_size = 1

# Change your epoch here
epoch = 1

# Change your train image root path here
train_img_path = "./train2014/"

# Change your train annot json path here
train_annot_path = "./annotations/captions_train2014.json"

# Change your device ("cpu" or "cuda")
device = "cuda"

# Change your vae model save path here (ends with ".pth")
vae_save_path = "./vae.pth"

# Change your dalle model save path here (ends with ".pth")
dalle_save_path = "./dalle.pth"

<font size="5">Data Preprocessing</font>

In [14]:
transform = T.Compose([
    T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
    T.Resize(input_image_size),
    T.CenterCrop(input_image_size),
    T.ToTensor()
])

train_data = CocoCaptions(
    root=train_img_path,
    annFile=train_annot_path,
    transform=transform
)

loading annotations into memory...
Done (t=0.90s)
creating index...
index created!


<font size="5">Create VAE Model</font>

In [15]:
vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,
    num_tokens = 8192,
    codebook_dim = 1024,
    hidden_dim = 64,
    num_resnet_blocks = 1,
    temperature = 0.9
).to(device)

tokenizer = SimpleTokenizer()

In [16]:
def get_trainable_params(model):
    return [params for params in model.parameters() if params.requires_grad]

In [18]:
train_size = len(train_data)
idx_list = range(0, train_size, batch_size)

opt = Adam(
    get_trainable_params(vae),
    lr = 3e-4,
    # weight_decay=0.01,
    # betas = (0.9, 0.999)
)
sched = ReduceLROnPlateau(
    opt,
    mode="min",
    factor=0.5,
    patience=10,
    cooldown=10,
    min_lr=1e-6,
    verbose=True,
)

for curr_epoch in range(epoch):
    print("Run training vae ...")
    print(f"Epoch {curr_epoch+1} / {epoch}")
    
    for batch_idx in tqdm(idx_list):
        if (batch_idx + batch_size) > train_size - 1:
            iter_idx = range(batch_idx, train_size, 1)
        else:
            iter_idx = range(batch_idx, batch_idx+batch_size, 1)

        batch_len = 0
        total_loss = torch.tensor(0., device=device)

        for curr_idx in iter_idx:
            image, _ = train_data[curr_idx]
            image = image.unsqueeze(0).type(torch.FloatTensor).to(device)
            
            # for text in texts:
            if total_loss == torch.tensor(0., device=device):
                total_loss = vae(image, return_loss=True)
            else:
                total_loss += vae(image, return_loss=True)
            batch_len += 1
                
        avg_loss = total_loss / batch_len

        opt.zero_grad()
        avg_loss.backward()
        opt.step()
        
        if batch_idx % 100 == 0:
            torch.save(vae.state_dict(), vae_save_path)
            print(f"average loss: {avg_loss.data}")
        
    sched.step(avg_loss)

torch.save(vae.state_dict(), vae_save_path)

Run training vae ...
Epoch 1 / 1


  0%|          | 4/82783 [00:01<4:45:57,  4.82it/s] 

average loss: 0.46610569953918457


  0%|          | 103/82783 [00:05<1:26:12, 15.98it/s]

average loss: 0.06004039943218231


  0%|          | 205/82783 [00:10<1:18:31, 17.53it/s]

average loss: 0.029740195721387863


  0%|          | 304/82783 [00:14<1:17:14, 17.80it/s]

average loss: 0.06832312792539597


  0%|          | 403/82783 [00:19<1:25:36, 16.04it/s]

average loss: 0.09563994407653809


  1%|          | 505/82783 [00:24<1:18:03, 17.57it/s]

average loss: 0.07715556770563126


  1%|          | 604/82783 [00:28<1:17:25, 17.69it/s]

average loss: 0.013978717848658562


  1%|          | 703/82783 [00:33<1:24:46, 16.14it/s]

average loss: 0.09663800895214081


  1%|          | 805/82783 [00:37<1:17:54, 17.54it/s]

average loss: 0.0560746006667614


  1%|          | 904/82783 [00:42<1:18:02, 17.48it/s]

average loss: 0.12497895210981369


  1%|          | 1003/82783 [00:46<1:24:34, 16.11it/s]

average loss: 0.04397071525454521


  1%|▏         | 1105/82783 [00:51<1:21:10, 16.77it/s]

average loss: 0.06164714694023132


  1%|▏         | 1203/82783 [00:56<1:30:33, 15.01it/s]

average loss: 0.04689127951860428


  2%|▏         | 1305/82783 [01:00<1:15:37, 17.96it/s]

average loss: 0.04399016499519348


  2%|▏         | 1404/82783 [01:05<1:16:45, 17.67it/s]

average loss: 0.025611644610762596


  2%|▏         | 1503/82783 [01:09<1:24:04, 16.11it/s]

average loss: 0.02462626062333584


  2%|▏         | 1605/82783 [01:14<1:14:52, 18.07it/s]

average loss: 0.013256857171654701


  2%|▏         | 1704/82783 [01:19<1:18:58, 17.11it/s]

average loss: 0.05003911256790161


  2%|▏         | 1803/82783 [01:23<1:23:49, 16.10it/s]

average loss: 0.017690733075141907


  2%|▏         | 1905/82783 [01:28<1:19:07, 17.04it/s]

average loss: 0.036749787628650665


  2%|▏         | 2004/82783 [01:32<1:17:25, 17.39it/s]

average loss: 0.016511939465999603


  3%|▎         | 2103/82783 [01:37<1:23:57, 16.02it/s]

average loss: 0.07671140879392624


  3%|▎         | 2205/82783 [01:41<1:16:01, 17.66it/s]

average loss: 0.0179927796125412


  3%|▎         | 2304/82783 [01:46<1:14:47, 17.93it/s]

average loss: 0.01875326782464981


  3%|▎         | 2403/82783 [01:50<1:24:18, 15.89it/s]

average loss: 0.03837057203054428


  3%|▎         | 2505/82783 [01:55<1:14:26, 17.97it/s]

average loss: 0.07276061177253723


  3%|▎         | 2604/82783 [01:59<1:14:31, 17.93it/s]

average loss: 0.034358225762844086


  3%|▎         | 2703/82783 [02:04<1:21:17, 16.42it/s]

average loss: 0.03659338131546974


  3%|▎         | 2805/82783 [02:09<1:16:18, 17.47it/s]

average loss: 0.03683756664395332


  4%|▎         | 2904/82783 [02:13<1:17:36, 17.15it/s]

average loss: 0.03579491004347801


  4%|▎         | 3003/82783 [02:17<1:22:17, 16.16it/s]

average loss: 0.01620299369096756


  4%|▍         | 3105/82783 [02:22<1:13:49, 17.99it/s]

average loss: 0.04115655645728111


  4%|▍         | 3204/82783 [02:26<1:13:56, 17.94it/s]

average loss: 0.03105243109166622


  4%|▍         | 3303/82783 [02:31<1:26:09, 15.37it/s]

average loss: 0.008071205578744411


  4%|▍         | 3405/82783 [02:36<1:18:08, 16.93it/s]

average loss: 0.02027730457484722


  4%|▍         | 3504/82783 [02:40<1:14:31, 17.73it/s]

average loss: 0.057387690991163254


  4%|▍         | 3603/82783 [02:45<1:20:58, 16.30it/s]

average loss: 0.019369065761566162


  4%|▍         | 3705/82783 [02:49<1:16:19, 17.27it/s]

average loss: 0.015221066772937775


  5%|▍         | 3804/82783 [02:54<1:10:34, 18.65it/s]

average loss: 0.018092140555381775


  5%|▍         | 3903/82783 [02:58<1:19:20, 16.57it/s]

average loss: 0.03918599709868431


  5%|▍         | 4005/82783 [03:02<1:10:02, 18.74it/s]

average loss: 0.029703032225370407


  5%|▍         | 4104/82783 [03:07<1:09:57, 18.74it/s]

average loss: 0.012872490100562572


  5%|▌         | 4206/82783 [03:11<1:11:35, 18.29it/s]

average loss: 0.04492416977882385


  5%|▌         | 4305/82783 [03:15<1:08:39, 19.05it/s]

average loss: 0.03066675178706646


  5%|▌         | 4404/82783 [03:20<1:12:36, 17.99it/s]

average loss: 0.04347613453865051


  5%|▌         | 4506/82783 [03:24<1:10:22, 18.54it/s]

average loss: 0.03136490657925606


  6%|▌         | 4605/82783 [03:28<1:10:29, 18.48it/s]

average loss: 0.07248479127883911


  6%|▌         | 4704/82783 [03:33<1:09:52, 18.62it/s]

average loss: 0.011797132901847363


  6%|▌         | 4806/82783 [03:37<1:08:56, 18.85it/s]

average loss: 0.030740071088075638


  6%|▌         | 4905/82783 [03:41<1:11:06, 18.25it/s]

average loss: 0.03793008625507355


  6%|▌         | 5004/82783 [03:46<1:10:16, 18.45it/s]

average loss: 0.044777762144804


  6%|▌         | 5103/82783 [03:50<1:16:27, 16.93it/s]

average loss: 0.02152317389845848


  6%|▋         | 5205/82783 [03:54<1:08:37, 18.84it/s]

average loss: 0.021877892315387726


  6%|▋         | 5304/82783 [03:59<1:10:03, 18.43it/s]

average loss: 0.036141298711299896


  7%|▋         | 5406/82783 [04:03<1:08:24, 18.85it/s]

average loss: 0.03446599096059799


  7%|▋         | 5505/82783 [04:07<1:11:54, 17.91it/s]

average loss: 0.026459667831659317


  7%|▋         | 5604/82783 [04:12<1:09:57, 18.39it/s]

average loss: 0.060072850435972214


  7%|▋         | 5706/82783 [04:16<1:09:22, 18.52it/s]

average loss: 0.04253831505775452


  7%|▋         | 5805/82783 [04:20<1:10:12, 18.28it/s]

average loss: 0.036804888397455215


  7%|▋         | 5904/82783 [04:25<1:08:30, 18.70it/s]

average loss: 0.05650933086872101


  7%|▋         | 6003/82783 [04:29<1:16:39, 16.69it/s]

average loss: 0.017798280343413353


  7%|▋         | 6105/82783 [04:33<1:10:16, 18.19it/s]

average loss: 0.006726976018399


  7%|▋         | 6204/82783 [04:38<1:06:49, 19.10it/s]

average loss: 0.010352704674005508


  8%|▊         | 6306/82783 [04:42<1:08:33, 18.59it/s]

average loss: 0.02247733250260353


  8%|▊         | 6405/82783 [04:46<1:09:14, 18.38it/s]

average loss: 0.04527439922094345


  8%|▊         | 6504/82783 [04:51<1:10:58, 17.91it/s]

average loss: 0.016567213460803032


  8%|▊         | 6603/82783 [04:55<1:14:00, 17.16it/s]

average loss: 0.045793406665325165


  8%|▊         | 6705/82783 [04:59<1:08:07, 18.61it/s]

average loss: 0.031387388706207275


  8%|▊         | 6804/82783 [05:04<1:09:17, 18.28it/s]

average loss: 0.01743226684629917


  8%|▊         | 6906/82783 [05:08<1:08:18, 18.51it/s]

average loss: 0.004185494966804981


  8%|▊         | 7005/82783 [05:12<1:07:16, 18.77it/s]

average loss: 0.007251922972500324


  9%|▊         | 7104/82783 [05:17<1:06:46, 18.89it/s]

average loss: 0.01356271468102932


  9%|▊         | 7203/82783 [05:21<1:16:31, 16.46it/s]

average loss: 0.029585829004645348


  9%|▉         | 7305/82783 [05:25<1:08:18, 18.42it/s]

average loss: 0.02634366601705551


  9%|▉         | 7404/82783 [05:30<1:09:52, 17.98it/s]

average loss: 0.013065584003925323


  9%|▉         | 7503/82783 [05:34<1:15:45, 16.56it/s]

average loss: 0.017728671431541443


  9%|▉         | 7605/82783 [05:39<1:08:31, 18.28it/s]

average loss: 0.022080719470977783


  9%|▉         | 7704/82783 [05:43<1:08:14, 18.34it/s]

average loss: 0.05149441212415695


  9%|▉         | 7803/82783 [05:47<1:16:35, 16.32it/s]

average loss: 0.008920286782085896


 10%|▉         | 7905/82783 [05:52<1:09:48, 17.88it/s]

average loss: 0.08810488879680634


 10%|▉         | 8004/82783 [05:56<1:07:46, 18.39it/s]

average loss: 0.01586325280368328


 10%|▉         | 8103/82783 [06:01<1:14:36, 16.68it/s]

average loss: 0.032851822674274445


 10%|▉         | 8205/82783 [06:05<1:06:33, 18.67it/s]

average loss: 0.008729960769414902


 10%|█         | 8304/82783 [06:09<1:09:03, 17.98it/s]

average loss: 0.0183913242071867


 10%|█         | 8406/82783 [06:14<1:08:53, 18.00it/s]

average loss: 0.021463517099618912


 10%|█         | 8505/82783 [06:18<1:07:15, 18.40it/s]

average loss: 0.042104385793209076


 10%|█         | 8604/82783 [06:23<1:08:47, 17.97it/s]

average loss: 0.021290387958288193


 11%|█         | 8706/82783 [06:27<1:07:54, 18.18it/s]

average loss: 0.016619069501757622


 11%|█         | 8805/82783 [06:31<1:08:44, 17.94it/s]

average loss: 0.021219482645392418


 11%|█         | 8904/82783 [06:36<1:06:59, 18.38it/s]

average loss: 0.009769181720912457


 11%|█         | 9003/82783 [06:40<1:14:30, 16.50it/s]

average loss: 0.01741671748459339


 11%|█         | 9105/82783 [06:45<1:08:57, 17.81it/s]

average loss: 0.03753097355365753


 11%|█         | 9204/82783 [06:49<1:10:08, 17.48it/s]

average loss: 0.04821419343352318


 11%|█         | 9306/82783 [06:54<1:07:28, 18.15it/s]

average loss: 0.028212934732437134


 11%|█▏        | 9405/82783 [06:58<1:09:46, 17.53it/s]

average loss: 0.024816524237394333


 11%|█▏        | 9504/82783 [07:03<1:08:14, 17.90it/s]

average loss: 0.024922452867031097


 12%|█▏        | 9606/82783 [07:07<1:07:21, 18.11it/s]

average loss: 0.01358439028263092


 12%|█▏        | 9705/82783 [07:11<1:06:36, 18.28it/s]

average loss: 0.029187319800257683


 12%|█▏        | 9804/82783 [07:16<1:07:16, 18.08it/s]

average loss: 0.00847986526787281


 12%|█▏        | 9903/82783 [07:20<1:13:37, 16.50it/s]

average loss: 0.018467631191015244


 12%|█▏        | 10005/82783 [07:25<1:05:39, 18.47it/s]

average loss: 0.033538106828927994


 12%|█▏        | 10104/82783 [07:29<1:08:15, 17.75it/s]

average loss: 0.031062578782439232


 12%|█▏        | 10203/82783 [07:34<1:14:43, 16.19it/s]

average loss: 0.01878046803176403


 12%|█▏        | 10305/82783 [07:38<1:06:38, 18.13it/s]

average loss: 0.012696526944637299


 13%|█▎        | 10404/82783 [07:43<1:08:27, 17.62it/s]

average loss: 0.046350203454494476


 13%|█▎        | 10506/82783 [07:47<1:05:43, 18.33it/s]

average loss: 0.015971627086400986


 13%|█▎        | 10605/82783 [07:52<1:06:59, 17.95it/s]

average loss: 0.016612041741609573


 13%|█▎        | 10704/82783 [07:56<1:05:58, 18.21it/s]

average loss: 0.031300850212574005


 13%|█▎        | 10803/82783 [08:01<1:12:37, 16.52it/s]

average loss: 0.011986421421170235


 13%|█▎        | 10905/82783 [08:05<1:06:00, 18.15it/s]

average loss: 0.019859924912452698


 13%|█▎        | 11004/82783 [08:10<1:07:28, 17.73it/s]

average loss: 0.015832584351301193


 13%|█▎        | 11103/82783 [08:14<1:14:50, 15.96it/s]

average loss: 0.015882940962910652


 14%|█▎        | 11205/82783 [08:19<1:06:18, 17.99it/s]

average loss: 0.01240652333945036


 14%|█▎        | 11304/82783 [08:23<1:08:01, 17.51it/s]

average loss: 0.021751387044787407


 14%|█▍        | 11403/82783 [08:28<1:13:38, 16.15it/s]

average loss: 0.02121940813958645


 14%|█▍        | 11505/82783 [08:32<1:07:05, 17.71it/s]

average loss: 0.020228343084454536


 14%|█▍        | 11604/82783 [08:37<1:06:21, 17.88it/s]

average loss: 0.025047849863767624


 14%|█▍        | 11703/82783 [08:41<1:14:21, 15.93it/s]

average loss: 0.013348198495805264


 14%|█▍        | 11805/82783 [08:46<1:07:47, 17.45it/s]

average loss: 0.019159087911248207


 14%|█▍        | 11904/82783 [08:50<1:07:52, 17.40it/s]

average loss: 0.04077925905585289


 14%|█▍        | 12003/82783 [08:55<1:13:58, 15.95it/s]

average loss: 0.028821051120758057


 15%|█▍        | 12105/82783 [08:59<1:06:11, 17.80it/s]

average loss: 0.030317086726427078


 15%|█▍        | 12204/82783 [09:04<1:04:51, 18.14it/s]

average loss: 0.015507657080888748


 15%|█▍        | 12303/82783 [09:08<1:10:50, 16.58it/s]

average loss: 0.05724100023508072


 15%|█▍        | 12405/82783 [09:13<1:06:15, 17.71it/s]

average loss: 0.015751678496599197


 15%|█▌        | 12504/82783 [09:17<1:04:17, 18.22it/s]

average loss: 0.008956676349043846


 15%|█▌        | 12606/82783 [09:22<1:04:53, 18.02it/s]

average loss: 0.03668142110109329


 15%|█▌        | 12705/82783 [09:26<1:05:57, 17.71it/s]

average loss: 0.012034284882247448


 15%|█▌        | 12804/82783 [09:31<1:04:10, 18.17it/s]

average loss: 0.008262522518634796


 16%|█▌        | 12903/82783 [09:35<1:11:11, 16.36it/s]

average loss: 0.023275218904018402


 16%|█▌        | 13005/82783 [09:40<1:05:11, 17.84it/s]

average loss: 0.014834616333246231


 16%|█▌        | 13104/82783 [09:44<1:06:11, 17.55it/s]

average loss: 0.01390252448618412


 16%|█▌        | 13203/82783 [09:49<1:12:06, 16.08it/s]

average loss: 0.028933130204677582


 16%|█▌        | 13305/82783 [09:54<1:05:49, 17.59it/s]

average loss: 0.02205967903137207


 16%|█▌        | 13404/82783 [09:58<1:05:48, 17.57it/s]

average loss: 0.026346027851104736


 16%|█▋        | 13503/82783 [10:03<1:11:09, 16.23it/s]

average loss: 0.017605453729629517


 16%|█▋        | 13605/82783 [10:07<1:05:23, 17.63it/s]

average loss: 0.024502307176589966


 17%|█▋        | 13704/82783 [10:12<1:04:21, 17.89it/s]

average loss: 0.03573056682944298


 17%|█▋        | 13803/82783 [10:16<1:13:13, 15.70it/s]

average loss: 0.022607143968343735


 17%|█▋        | 13905/82783 [10:21<1:06:04, 17.37it/s]

average loss: 0.01144975796341896


 17%|█▋        | 14004/82783 [10:25<1:05:46, 17.43it/s]

average loss: 0.032572608441114426


 17%|█▋        | 14103/82783 [10:30<1:09:06, 16.56it/s]

average loss: 0.009844179265201092


 17%|█▋        | 14205/82783 [10:35<1:05:11, 17.53it/s]

average loss: 0.018497686833143234


 17%|█▋        | 14304/82783 [10:39<1:04:14, 17.77it/s]

average loss: 0.022641733288764954


 17%|█▋        | 14403/82783 [10:44<1:10:21, 16.20it/s]

average loss: 0.033323340117931366


 18%|█▊        | 14505/82783 [10:48<1:04:17, 17.70it/s]

average loss: 0.03381817787885666


 18%|█▊        | 14604/82783 [10:53<1:03:27, 17.91it/s]

average loss: 0.01629612408578396


 18%|█▊        | 14703/82783 [10:57<1:11:13, 15.93it/s]

average loss: 0.022225316613912582


 18%|█▊        | 14805/82783 [11:02<1:02:47, 18.04it/s]

average loss: 0.022602630779147148


 18%|█▊        | 14904/82783 [11:06<1:03:24, 17.84it/s]

average loss: 0.031176356598734856


 18%|█▊        | 15003/82783 [11:11<1:10:54, 15.93it/s]

average loss: 0.013370986096560955


 18%|█▊        | 15105/82783 [11:16<1:05:39, 17.18it/s]

average loss: 0.013395296409726143


 18%|█▊        | 15204/82783 [11:20<1:05:20, 17.24it/s]

average loss: 0.012830889783799648


 18%|█▊        | 15303/82783 [11:25<1:08:52, 16.33it/s]

average loss: 0.02144380286335945


 19%|█▊        | 15405/82783 [11:29<1:02:01, 18.10it/s]

average loss: 0.009335096925497055


 19%|█▊        | 15504/82783 [11:34<1:03:03, 17.78it/s]

average loss: 0.012560434639453888


 19%|█▉        | 15603/82783 [11:39<1:11:29, 15.66it/s]

average loss: 0.022490879520773888


 19%|█▉        | 15705/82783 [11:43<1:03:48, 17.52it/s]

average loss: 0.005219507031142712


 19%|█▉        | 15803/82783 [11:48<1:10:41, 15.79it/s]

average loss: 0.008537168614566326


 19%|█▉        | 15905/82783 [11:52<1:02:10, 17.93it/s]

average loss: 0.00991239957511425


 19%|█▉        | 16004/82783 [11:57<1:04:15, 17.32it/s]

average loss: 0.04276920482516289


 19%|█▉        | 16103/82783 [12:01<1:09:29, 15.99it/s]

average loss: 0.016533860936760902


 20%|█▉        | 16205/82783 [12:06<1:02:01, 17.89it/s]

average loss: 0.026598608121275902


 20%|█▉        | 16304/82783 [12:11<1:02:20, 17.77it/s]

average loss: 0.04250791668891907


 20%|█▉        | 16403/82783 [12:15<1:08:07, 16.24it/s]

average loss: 0.01386189740151167


 20%|█▉        | 16504/82783 [12:20<1:02:52, 17.57it/s]

average loss: 0.004724396392703056


 20%|██        | 16603/82783 [12:24<1:07:46, 16.27it/s]

average loss: 0.015321771614253521


 20%|██        | 16705/82783 [12:29<1:03:09, 17.44it/s]

average loss: 0.03203591704368591


 20%|██        | 16804/82783 [12:33<1:01:49, 17.79it/s]

average loss: 0.028859248384833336


 20%|██        | 16903/82783 [12:38<1:07:07, 16.36it/s]

average loss: 0.023223720490932465


 21%|██        | 17005/82783 [12:43<1:02:00, 17.68it/s]

average loss: 0.01898178458213806


 21%|██        | 17104/82783 [12:47<1:04:59, 16.84it/s]

average loss: 0.006767936050891876


 21%|██        | 17203/82783 [12:52<1:08:06, 16.05it/s]

average loss: 0.008346865884959698


 21%|██        | 17305/82783 [12:56<1:03:22, 17.22it/s]

average loss: 0.00728354137390852


 21%|██        | 17404/82783 [13:01<1:01:59, 17.58it/s]

average loss: 0.014820532873272896


 21%|██        | 17503/82783 [13:06<1:06:12, 16.43it/s]

average loss: 0.00733990129083395


 21%|██▏       | 17605/82783 [13:10<1:01:42, 17.60it/s]

average loss: 0.005988049320876598


 21%|██▏       | 17704/82783 [13:15<1:01:17, 17.69it/s]

average loss: 0.022694196552038193


 22%|██▏       | 17803/82783 [13:19<1:06:34, 16.27it/s]

average loss: 0.00877899955958128


 22%|██▏       | 17905/82783 [13:24<1:06:15, 16.32it/s]

average loss: 0.02430647611618042


 22%|██▏       | 18004/82783 [13:29<59:41, 18.09it/s]  

average loss: 0.006737608462572098


 22%|██▏       | 18103/82783 [13:33<1:06:39, 16.17it/s]

average loss: 0.02002817578613758


 22%|██▏       | 18205/82783 [13:38<1:01:29, 17.50it/s]

average loss: 0.0160063486546278


 22%|██▏       | 18304/82783 [13:42<1:00:23, 17.79it/s]

average loss: 0.019614076241850853


 22%|██▏       | 18403/82783 [13:47<1:06:54, 16.04it/s]

average loss: 0.01564989611506462


 22%|██▏       | 18505/82783 [13:52<1:00:47, 17.62it/s]

average loss: 0.043861549347639084


 22%|██▏       | 18603/82783 [13:56<1:08:28, 15.62it/s]

average loss: 0.005639529787003994


 23%|██▎       | 18705/82783 [14:01<1:01:08, 17.47it/s]

average loss: 0.02040502242743969


 23%|██▎       | 18804/82783 [14:05<1:00:04, 17.75it/s]

average loss: 0.019344553351402283


 23%|██▎       | 18903/82783 [14:10<1:07:45, 15.71it/s]

average loss: 0.015279032289981842


 23%|██▎       | 19005/82783 [14:15<1:00:29, 17.57it/s]

average loss: 0.02534370869398117


 23%|██▎       | 19104/82783 [14:19<59:36, 17.80it/s]  

average loss: 0.052909284830093384


 23%|██▎       | 19203/82783 [14:24<1:06:39, 15.90it/s]

average loss: 0.015530755743384361


 23%|██▎       | 19305/82783 [14:29<59:29, 17.79it/s]  

average loss: 0.008908826857805252


 23%|██▎       | 19404/82783 [14:33<57:58, 18.22it/s]  

average loss: 0.018185043707489967


 24%|██▎       | 19503/82783 [14:38<1:05:43, 16.05it/s]

average loss: 0.002595354802906513


 24%|██▎       | 19605/82783 [14:42<59:24, 17.73it/s]  

average loss: 0.022908657789230347


 24%|██▍       | 19704/82783 [14:47<59:14, 17.74it/s]  

average loss: 0.011222699657082558


 24%|██▍       | 19803/82783 [14:52<1:06:34, 15.77it/s]

average loss: 0.01584606245160103


 24%|██▍       | 19905/82783 [14:56<58:51, 17.81it/s]  

average loss: 0.010782657191157341


 24%|██▍       | 20004/82783 [15:01<1:00:18, 17.35it/s]

average loss: 0.026201676577329636


 24%|██▍       | 20103/82783 [15:05<1:04:35, 16.17it/s]

average loss: 0.008105037733912468


 24%|██▍       | 20205/82783 [15:10<58:44, 17.75it/s]  

average loss: 0.005127348005771637


 25%|██▍       | 20304/82783 [15:15<1:00:05, 17.33it/s]

average loss: 0.028143685311079025


 25%|██▍       | 20403/82783 [15:19<1:03:14, 16.44it/s]

average loss: 0.031234828755259514


 25%|██▍       | 20503/82783 [15:24<1:05:17, 15.90it/s]

average loss: 0.011586226522922516


 25%|██▍       | 20605/82783 [15:29<1:00:04, 17.25it/s]

average loss: 0.017048951238393784


 25%|██▌       | 20704/82783 [15:33<59:15, 17.46it/s]  

average loss: 0.020176716148853302


 25%|██▌       | 20803/82783 [15:38<1:03:28, 16.27it/s]

average loss: 0.00651866290718317


 25%|██▌       | 20905/82783 [15:42<57:48, 17.84it/s]  

average loss: 0.008485055528581142


 25%|██▌       | 21004/82783 [15:47<58:39, 17.55it/s]  

average loss: 0.0075258477590978146


 25%|██▌       | 21103/82783 [15:52<1:04:58, 15.82it/s]

average loss: 0.010081851854920387


 26%|██▌       | 21205/82783 [15:56<59:17, 17.31it/s]  

average loss: 0.02236178144812584


 26%|██▌       | 21304/82783 [16:01<58:06, 17.63it/s]  

average loss: 0.014324914664030075


 26%|██▌       | 21403/82783 [16:06<1:02:38, 16.33it/s]

average loss: 0.004794283304363489


 26%|██▌       | 21505/82783 [16:10<57:48, 17.67it/s]  

average loss: 0.012643745169043541


 26%|██▌       | 21604/82783 [16:15<58:14, 17.51it/s]  

average loss: 0.0169321671128273


 26%|██▌       | 21703/82783 [16:19<1:03:13, 16.10it/s]

average loss: 0.01008526049554348


 26%|██▋       | 21805/82783 [16:24<57:59, 17.53it/s]  

average loss: 0.0147326048463583


 26%|██▋       | 21904/82783 [16:29<57:31, 17.64it/s]  

average loss: 0.005840163677930832


 27%|██▋       | 22003/82783 [16:33<1:04:26, 15.72it/s]

average loss: 0.016640400514006615


 27%|██▋       | 22105/82783 [16:38<56:56, 17.76it/s]  

average loss: 0.00716942735016346


 27%|██▋       | 22204/82783 [16:43<56:27, 17.88it/s]  

average loss: 0.006835078354924917


 27%|██▋       | 22303/82783 [16:47<1:02:39, 16.09it/s]

average loss: 0.017844725400209427


 27%|██▋       | 22405/82783 [16:52<58:17, 17.26it/s]  

average loss: 0.0200924314558506


 27%|██▋       | 22504/82783 [16:56<57:55, 17.34it/s]  

average loss: 0.012178463861346245


 27%|██▋       | 22603/82783 [17:01<1:00:45, 16.51it/s]

average loss: 0.017181754112243652


 27%|██▋       | 22705/82783 [17:06<56:47, 17.63it/s]  

average loss: 0.031424473971128464


 28%|██▊       | 22804/82783 [17:10<56:41, 17.64it/s]  

average loss: 0.012820884585380554


 28%|██▊       | 22903/82783 [17:15<1:01:10, 16.32it/s]

average loss: 0.05077540501952171


 28%|██▊       | 23005/82783 [17:20<56:03, 17.77it/s]  

average loss: 0.009235281497240067


 28%|██▊       | 23104/82783 [17:24<56:33, 17.59it/s]  

average loss: 0.028137080371379852


 28%|██▊       | 23203/82783 [17:29<1:03:02, 15.75it/s]

average loss: 0.013477346859872341


 28%|██▊       | 23305/82783 [17:33<55:52, 17.74it/s]  

average loss: 0.006281513255089521


 28%|██▊       | 23404/82783 [17:38<56:09, 17.62it/s]  

average loss: 0.016585547477006912


 28%|██▊       | 23503/82783 [17:43<1:02:14, 15.88it/s]

average loss: 0.004784043878316879


 29%|██▊       | 23605/82783 [17:47<56:16, 17.53it/s]  

average loss: 0.013357138261198997


 29%|██▊       | 23704/82783 [17:52<57:01, 17.27it/s]  

average loss: 0.01783263310790062


 29%|██▉       | 23803/82783 [17:57<1:01:46, 15.91it/s]

average loss: 0.020043058320879936


 29%|██▉       | 23905/82783 [18:01<56:11, 17.46it/s]  

average loss: 0.006175223737955093


 29%|██▉       | 24004/82783 [18:06<57:55, 16.91it/s]  

average loss: 0.007599920965731144


 29%|██▉       | 24103/82783 [18:10<1:00:51, 16.07it/s]

average loss: 0.009896881878376007


 29%|██▉       | 24204/82783 [18:15<55:41, 17.53it/s]  

average loss: 0.021511636674404144


 29%|██▉       | 24303/82783 [18:20<1:01:18, 15.90it/s]

average loss: 0.024077877402305603


 29%|██▉       | 24404/82783 [18:24<55:37, 17.49it/s]  

average loss: 0.020400777459144592


 30%|██▉       | 24503/82783 [18:29<59:30, 16.32it/s]

average loss: 0.011704175733029842


 30%|██▉       | 24605/82783 [18:34<55:27, 17.48it/s]  

average loss: 0.013821905478835106


 30%|██▉       | 24704/82783 [18:38<53:47, 18.00it/s]  

average loss: 0.03229988366365433


 30%|██▉       | 24803/82783 [18:43<1:01:02, 15.83it/s]

average loss: 0.014457996934652328


 30%|███       | 24905/82783 [18:48<54:31, 17.69it/s]  

average loss: 0.013252737931907177


 30%|███       | 25004/82783 [18:52<54:25, 17.69it/s]

average loss: 0.02072364278137684


 30%|███       | 25103/82783 [18:57<1:00:03, 16.01it/s]

average loss: 0.013016283512115479


 30%|███       | 25205/82783 [19:01<55:50, 17.18it/s]  

average loss: 0.019099874421954155


 31%|███       | 25304/82783 [19:06<55:23, 17.30it/s]  

average loss: 0.018446672707796097


 31%|███       | 25403/82783 [19:11<59:57, 15.95it/s]

average loss: 0.00486006448045373


 31%|███       | 25504/82783 [19:15<55:46, 17.12it/s]  

average loss: 0.014936663210391998


 31%|███       | 25603/82783 [19:20<1:00:23, 15.78it/s]

average loss: 0.005467369221150875


 31%|███       | 25705/82783 [19:25<54:22, 17.50it/s]  

average loss: 0.003978065215051174


 31%|███       | 25804/82783 [19:29<52:12, 18.19it/s]

average loss: 0.023168951272964478


 31%|███▏      | 25903/82783 [19:34<1:02:05, 15.27it/s]

average loss: 0.022209983319044113


 31%|███▏      | 26005/82783 [19:39<53:57, 17.54it/s]  

average loss: 0.010500632226467133


 32%|███▏      | 26104/82783 [19:43<53:39, 17.60it/s]  

average loss: 0.023537620902061462


 32%|███▏      | 26203/82783 [19:48<59:15, 15.91it/s]

average loss: 0.003048658836632967


 32%|███▏      | 26305/82783 [19:52<53:55, 17.45it/s]

average loss: 0.015965407714247704


 32%|███▏      | 26404/82783 [19:57<53:14, 17.65it/s]

average loss: 0.016603877767920494


 32%|███▏      | 26503/82783 [20:02<59:59, 15.64it/s]

average loss: 0.015692342072725296


 32%|███▏      | 26605/82783 [20:06<53:08, 17.62it/s]

average loss: 0.009038480930030346


 32%|███▏      | 26704/82783 [20:11<53:13, 17.56it/s]  

average loss: 0.011059543117880821


 32%|███▏      | 26803/82783 [20:15<58:25, 15.97it/s]

average loss: 0.0066578686237335205


 33%|███▎      | 26905/82783 [20:20<54:14, 17.17it/s]  

average loss: 0.01618168316781521


 33%|███▎      | 27004/82783 [20:25<53:01, 17.53it/s]

average loss: 0.017293978482484818


 33%|███▎      | 27103/82783 [20:29<57:26, 16.16it/s]

average loss: 0.012173516675829887


 33%|███▎      | 27205/82783 [20:34<52:57, 17.49it/s]

average loss: 0.013023456558585167


 33%|███▎      | 27304/82783 [20:39<53:39, 17.23it/s]

average loss: 0.006077542435377836


 33%|███▎      | 27403/82783 [20:43<56:28, 16.34it/s]

average loss: 0.00884760357439518


 33%|███▎      | 27505/82783 [20:48<53:30, 17.22it/s]

average loss: 0.00932370312511921


 33%|███▎      | 27604/82783 [20:53<51:38, 17.81it/s]

average loss: 0.013391342014074326


 33%|███▎      | 27703/82783 [20:57<56:52, 16.14it/s]

average loss: 0.01728433556854725


 34%|███▎      | 27805/82783 [21:02<52:28, 17.46it/s]

average loss: 0.00888452585786581


 34%|███▎      | 27904/82783 [21:06<52:16, 17.50it/s]

average loss: 0.016207121312618256


 34%|███▍      | 28003/82783 [21:11<56:05, 16.28it/s]

average loss: 0.009299174882471561


 34%|███▍      | 28105/82783 [21:16<52:18, 17.42it/s]

average loss: 0.007129488047212362


 34%|███▍      | 28204/82783 [21:20<51:59, 17.50it/s]

average loss: 0.00963403657078743


 34%|███▍      | 28303/82783 [21:25<55:03, 16.49it/s]

average loss: 0.0057091000489890575


 34%|███▍      | 28404/82783 [21:30<54:49, 16.53it/s]

average loss: 0.021495889872312546


 34%|███▍      | 28503/82783 [21:34<56:26, 16.03it/s]

average loss: 0.03501821309328079


 35%|███▍      | 28605/82783 [21:39<51:44, 17.45it/s]

average loss: 0.031001625582575798


 35%|███▍      | 28704/82783 [21:44<51:25, 17.53it/s]

average loss: 0.026885878294706345


 35%|███▍      | 28803/82783 [21:48<56:01, 16.06it/s]

average loss: 0.021311655640602112


 35%|███▍      | 28905/82783 [21:53<50:40, 17.72it/s]

average loss: 0.008625312708318233


 35%|███▌      | 29004/82783 [21:57<51:42, 17.33it/s]

average loss: 0.008769193664193153


 35%|███▌      | 29103/82783 [22:02<57:06, 15.67it/s]

average loss: 0.025725528597831726


 35%|███▌      | 29205/82783 [22:07<51:12, 17.44it/s]

average loss: 0.013520030304789543


 35%|███▌      | 29304/82783 [22:11<51:43, 17.23it/s]

average loss: 0.00795706920325756


 36%|███▌      | 29403/82783 [22:16<56:40, 15.70it/s]

average loss: 0.017837930470705032


 36%|███▌      | 29505/82783 [22:21<50:20, 17.64it/s]

average loss: 0.00818687491118908


 36%|███▌      | 29604/82783 [22:25<49:01, 18.08it/s]

average loss: 0.010594196617603302


 36%|███▌      | 29703/82783 [22:30<56:29, 15.66it/s]

average loss: 0.008948600850999355


 36%|███▌      | 29805/82783 [22:34<49:42, 17.76it/s]

average loss: 0.005568355321884155


 36%|███▌      | 29903/82783 [22:39<56:25, 15.62it/s]

average loss: 0.029718702659010887


 36%|███▌      | 30005/82783 [22:44<49:58, 17.60it/s]

average loss: 0.0039277458563447


 36%|███▋      | 30104/82783 [22:48<49:31, 17.73it/s]

average loss: 0.0072760870680212975


 36%|███▋      | 30203/82783 [22:53<55:10, 15.88it/s]

average loss: 0.007629533763974905


 37%|███▋      | 30305/82783 [22:58<49:46, 17.57it/s]

average loss: 0.007881829515099525


 37%|███▋      | 30404/82783 [23:02<49:51, 17.51it/s]

average loss: 0.001668774290010333


 37%|███▋      | 30503/82783 [23:07<53:58, 16.14it/s]

average loss: 0.02007124572992325


 37%|███▋      | 30605/82783 [23:12<50:54, 17.08it/s]

average loss: 0.009409507736563683


 37%|███▋      | 30704/82783 [23:16<49:55, 17.39it/s]

average loss: 0.009657880291342735


 37%|███▋      | 30804/82783 [23:21<50:38, 17.11it/s]

average loss: 0.009022757411003113


 37%|███▋      | 30903/82783 [23:25<54:03, 16.00it/s]

average loss: 0.013616800308227539


 37%|███▋      | 31005/82783 [23:30<51:13, 16.85it/s]

average loss: 0.011503135785460472


 38%|███▊      | 31104/82783 [23:35<49:23, 17.44it/s]

average loss: 0.007037631701678038


 38%|███▊      | 31205/82783 [23:39<48:04, 17.88it/s]

average loss: 0.023862993344664574


 38%|███▊      | 31304/82783 [23:44<49:34, 17.31it/s]

average loss: 0.03309648483991623


 38%|███▊      | 31403/82783 [23:49<54:01, 15.85it/s]

average loss: 0.023623300716280937


 38%|███▊      | 31505/82783 [23:53<47:49, 17.87it/s]

average loss: 0.012948352843523026


 38%|███▊      | 31604/82783 [23:58<49:38, 17.18it/s]

average loss: 0.030523907393217087


 38%|███▊      | 31705/82783 [24:03<49:59, 17.03it/s]

average loss: 0.006383964326232672


 38%|███▊      | 31804/82783 [24:07<48:25, 17.54it/s]

average loss: 0.026850737631320953


 39%|███▊      | 31903/82783 [24:12<53:22, 15.89it/s]

average loss: 0.013464200310409069


 39%|███▊      | 32005/82783 [24:16<49:15, 17.18it/s]

average loss: 0.02909991145133972


 39%|███▉      | 32104/82783 [24:21<47:16, 17.87it/s]

average loss: 0.016619110479950905


 39%|███▉      | 32203/82783 [24:26<53:19, 15.81it/s]

average loss: 0.02802152931690216


 39%|███▉      | 32305/82783 [24:30<48:33, 17.33it/s]

average loss: 0.01436326839029789


 39%|███▉      | 32404/82783 [24:35<48:35, 17.28it/s]

average loss: 0.011706044897437096


 39%|███▉      | 32503/82783 [24:40<53:22, 15.70it/s]

average loss: 0.008038364350795746


 39%|███▉      | 32605/82783 [24:44<48:55, 17.10it/s]

average loss: 0.03202849626541138


 40%|███▉      | 32704/82783 [24:49<47:00, 17.76it/s]

average loss: 0.0158379003405571


 40%|███▉      | 32803/82783 [24:54<50:41, 16.43it/s]

average loss: 0.031557418406009674


 40%|███▉      | 32905/82783 [24:58<47:44, 17.41it/s]

average loss: 0.006740821525454521


 40%|███▉      | 33004/82783 [25:03<47:27, 17.48it/s]

average loss: 0.010678990744054317


 40%|███▉      | 33103/82783 [25:07<51:10, 16.18it/s]

average loss: 0.024964651092886925


 40%|████      | 33205/82783 [25:12<46:54, 17.62it/s]

average loss: 0.011260702274739742


 40%|████      | 33304/82783 [25:17<47:11, 17.47it/s]

average loss: 0.023430615663528442


 40%|████      | 33405/82783 [25:21<47:46, 17.22it/s]

average loss: 0.008734113536775112


 40%|████      | 33504/82783 [25:26<48:17, 17.01it/s]

average loss: 0.02441389486193657


 41%|████      | 33603/82783 [25:31<51:12, 16.01it/s]

average loss: 0.006702845450490713


 41%|████      | 33705/82783 [25:35<46:57, 17.42it/s]

average loss: 0.04480388015508652


 41%|████      | 33804/82783 [25:40<47:02, 17.35it/s]

average loss: 0.010667165741324425


 41%|████      | 33903/82783 [25:44<51:10, 15.92it/s]

average loss: 0.011368087492883205


 41%|████      | 34005/82783 [25:49<47:00, 17.29it/s]

average loss: 0.01428142748773098


 41%|████      | 34104/82783 [25:54<46:57, 17.28it/s]

average loss: 0.051917046308517456


 41%|████▏     | 34203/82783 [25:58<50:55, 15.90it/s]

average loss: 0.026493055745959282


 41%|████▏     | 34305/82783 [26:03<45:21, 17.81it/s]

average loss: 0.012246919795870781


 42%|████▏     | 34404/82783 [26:08<47:58, 16.81it/s]

average loss: 0.016330838203430176


 42%|████▏     | 34503/82783 [26:12<50:01, 16.09it/s]

average loss: 0.006994161754846573


 42%|████▏     | 34605/82783 [26:17<45:38, 17.59it/s]

average loss: 0.0254824236035347


 42%|████▏     | 34704/82783 [26:21<46:26, 17.25it/s]

average loss: 0.009618008509278297


 42%|████▏     | 34803/82783 [26:26<51:31, 15.52it/s]

average loss: 0.00871985126286745


 42%|████▏     | 34905/82783 [26:31<45:37, 17.49it/s]

average loss: 0.024995502084493637


 42%|████▏     | 35003/82783 [26:35<49:47, 16.00it/s]

average loss: 0.012016965076327324


 42%|████▏     | 35105/82783 [26:40<45:11, 17.59it/s]

average loss: 0.013020428828895092


 43%|████▎     | 35204/82783 [26:45<44:55, 17.65it/s]

average loss: 0.009928221814334393


 43%|████▎     | 35303/82783 [26:49<49:41, 15.93it/s]

average loss: 0.0065333908423781395


 43%|████▎     | 35405/82783 [26:54<45:01, 17.54it/s]

average loss: 0.018206171691417694


 43%|████▎     | 35503/82783 [26:59<49:52, 15.80it/s]

average loss: 0.026284154504537582


 43%|████▎     | 35604/82783 [27:03<46:01, 17.08it/s]

average loss: 0.00249488465487957


 43%|████▎     | 35703/82783 [27:08<49:43, 15.78it/s]

average loss: 0.005685566924512386


 43%|████▎     | 35805/82783 [27:13<45:14, 17.31it/s]

average loss: 0.004758074879646301


 43%|████▎     | 35904/82783 [27:17<46:19, 16.86it/s]

average loss: 0.019413329660892487


 43%|████▎     | 36003/82783 [27:22<49:46, 15.66it/s]

average loss: 0.021808287128806114


 44%|████▎     | 36105/82783 [27:26<44:28, 17.49it/s]

average loss: 0.007820414379239082


 44%|████▎     | 36204/82783 [27:31<43:59, 17.65it/s]

average loss: 0.011542409658432007


 44%|████▍     | 36303/82783 [27:36<48:39, 15.92it/s]

average loss: 0.012194654904305935


 44%|████▍     | 36404/82783 [27:40<45:32, 16.97it/s]

average loss: 0.005706227850168943


 44%|████▍     | 36503/82783 [27:45<48:44, 15.83it/s]

average loss: 0.00813221000134945


 44%|████▍     | 36605/82783 [27:50<44:19, 17.36it/s]

average loss: 0.01472613587975502


 44%|████▍     | 36704/82783 [27:54<45:05, 17.03it/s]

average loss: 0.009349210187792778


 44%|████▍     | 36803/82783 [27:59<47:04, 16.28it/s]

average loss: 0.014090071432292461


 45%|████▍     | 36905/82783 [28:04<43:37, 17.53it/s]

average loss: 0.028876570984721184


 45%|████▍     | 37004/82783 [28:08<43:25, 17.57it/s]

average loss: 0.018982872366905212


 45%|████▍     | 37103/82783 [28:13<47:39, 15.98it/s]

average loss: 0.013478132896125317


 45%|████▍     | 37205/82783 [28:17<43:31, 17.45it/s]

average loss: 0.016004640609025955


 45%|████▌     | 37304/82783 [28:22<43:11, 17.55it/s]

average loss: 0.008355138823390007


 45%|████▌     | 37403/82783 [28:27<47:54, 15.79it/s]

average loss: 0.018700849264860153


 45%|████▌     | 37505/82783 [28:31<44:06, 17.11it/s]

average loss: 0.01294534932821989


 45%|████▌     | 37604/82783 [28:36<42:33, 17.69it/s]

average loss: 0.01766066625714302


 46%|████▌     | 37703/82783 [28:40<47:10, 15.93it/s]

average loss: 0.012164230458438396


 46%|████▌     | 37805/82783 [28:45<43:08, 17.38it/s]

average loss: 0.019713036715984344


 46%|████▌     | 37904/82783 [28:50<43:11, 17.32it/s]

average loss: 0.005025561433285475


 46%|████▌     | 38003/82783 [28:54<46:37, 16.01it/s]

average loss: 0.014071216806769371


 46%|████▌     | 38104/82783 [28:59<45:01, 16.54it/s]

average loss: 0.01737508922815323


 46%|████▌     | 38203/82783 [29:04<46:03, 16.13it/s]

average loss: 0.013665169477462769


 46%|████▋     | 38305/82783 [29:08<43:00, 17.24it/s]

average loss: 0.00915562454611063


 46%|████▋     | 38404/82783 [29:13<41:14, 17.94it/s]

average loss: 0.015087053179740906


 47%|████▋     | 38503/82783 [29:17<46:30, 15.87it/s]

average loss: 0.022368963807821274


 47%|████▋     | 38604/82783 [29:22<42:34, 17.29it/s]

average loss: 0.013242046348750591


 47%|████▋     | 38703/82783 [29:27<46:21, 15.85it/s]

average loss: 0.00859297439455986


 47%|████▋     | 38805/82783 [29:32<43:00, 17.04it/s]

average loss: 0.005747186951339245


 47%|████▋     | 38904/82783 [29:36<41:01, 17.83it/s]

average loss: 0.019315343350172043


 47%|████▋     | 39003/82783 [29:41<46:08, 15.81it/s]

average loss: 0.009555796161293983


 47%|████▋     | 39104/82783 [29:45<43:14, 16.84it/s]

average loss: 0.009111655876040459


 47%|████▋     | 39203/82783 [29:50<45:01, 16.13it/s]

average loss: 0.010268563404679298


 47%|████▋     | 39305/82783 [29:55<41:32, 17.45it/s]

average loss: 0.011017711833119392


 48%|████▊     | 39404/82783 [29:59<41:24, 17.46it/s]

average loss: 0.012703314423561096


 48%|████▊     | 39503/82783 [30:04<46:21, 15.56it/s]

average loss: 0.01090945489704609


 48%|████▊     | 39605/82783 [30:09<41:36, 17.30it/s]

average loss: 0.008221687749028206


 48%|████▊     | 39704/82783 [30:13<41:32, 17.29it/s]

average loss: 0.022394966334104538


 48%|████▊     | 39803/82783 [30:18<45:50, 15.63it/s]

average loss: 0.012580543756484985


 48%|████▊     | 39905/82783 [30:23<41:18, 17.30it/s]

average loss: 0.01969033107161522


 48%|████▊     | 40004/82783 [30:27<40:58, 17.40it/s]

average loss: 0.007708635181188583


 48%|████▊     | 40103/82783 [30:32<46:04, 15.44it/s]

average loss: 0.015331752598285675


 49%|████▊     | 40205/82783 [30:36<40:12, 17.65it/s]

average loss: 0.006560568697750568


 49%|████▊     | 40304/82783 [30:41<41:12, 17.18it/s]

average loss: 0.0062498897314071655


 49%|████▉     | 40403/82783 [30:46<45:04, 15.67it/s]

average loss: 0.008589230477809906


 49%|████▉     | 40505/82783 [30:50<39:39, 17.77it/s]

average loss: 0.01964416168630123


 49%|████▉     | 40604/82783 [30:55<40:14, 17.47it/s]

average loss: 0.016956349834799767


 49%|████▉     | 40703/82783 [31:00<45:08, 15.54it/s]

average loss: 0.010330485180020332


 49%|████▉     | 40805/82783 [31:04<38:56, 17.96it/s]

average loss: 0.025019392371177673


 49%|████▉     | 40904/82783 [31:09<39:59, 17.46it/s]

average loss: 0.003800090402364731


 50%|████▉     | 41003/82783 [31:13<43:30, 16.01it/s]

average loss: 0.0040571605786681175


 50%|████▉     | 41105/82783 [31:18<39:23, 17.64it/s]

average loss: 0.012112977914512157


 50%|████▉     | 41204/82783 [31:23<40:02, 17.30it/s]

average loss: 0.005647359415888786


 50%|████▉     | 41305/82783 [31:27<39:26, 17.52it/s]

average loss: 0.019390640780329704


 50%|█████     | 41404/82783 [31:32<38:42, 17.82it/s]

average loss: 0.012214519083499908


 50%|█████     | 41503/82783 [31:37<43:12, 15.92it/s]

average loss: 0.004275203682482243


 50%|█████     | 41605/82783 [31:41<38:51, 17.66it/s]

average loss: 0.02532707154750824


 50%|█████     | 41704/82783 [31:46<38:57, 17.58it/s]

average loss: 0.010489888489246368


 50%|█████     | 41803/82783 [31:50<42:44, 15.98it/s]

average loss: 0.008676501922309399


 51%|█████     | 41905/82783 [31:55<39:10, 17.39it/s]

average loss: 0.00918603129684925


 51%|█████     | 42004/82783 [32:00<39:44, 17.10it/s]

average loss: 0.006659429520368576


 51%|█████     | 42105/82783 [32:05<38:26, 17.64it/s]

average loss: 0.031023874878883362


 51%|█████     | 42204/82783 [32:09<38:01, 17.79it/s]

average loss: 0.014030059799551964


 51%|█████     | 42303/82783 [32:14<42:27, 15.89it/s]

average loss: 0.004667173605412245


 51%|█████     | 42405/82783 [32:18<38:21, 17.54it/s]

average loss: 0.010758854448795319


 51%|█████▏    | 42504/82783 [32:23<38:08, 17.60it/s]

average loss: 0.02200758270919323


 51%|█████▏    | 42603/82783 [32:28<42:41, 15.69it/s]

average loss: 0.019519973546266556


 52%|█████▏    | 42704/82783 [32:32<38:14, 17.46it/s]

average loss: 0.01663704216480255


 52%|█████▏    | 42803/82783 [32:37<40:53, 16.29it/s]

average loss: 0.007311288267374039


 52%|█████▏    | 42905/82783 [32:42<37:24, 17.77it/s]

average loss: 0.008504537865519524


 52%|█████▏    | 43004/82783 [32:46<38:23, 17.27it/s]

average loss: 0.010184640064835548


 52%|█████▏    | 43103/82783 [32:51<42:04, 15.72it/s]

average loss: 0.012340225279331207


 52%|█████▏    | 43205/82783 [32:56<38:00, 17.35it/s]

average loss: 0.009931991808116436


 52%|█████▏    | 43304/82783 [33:00<39:39, 16.59it/s]

average loss: 0.0024508058559149504


 52%|█████▏    | 43403/82783 [33:05<41:35, 15.78it/s]

average loss: 0.016048619523644447


 53%|█████▎    | 43505/82783 [33:10<38:19, 17.08it/s]

average loss: 0.009792334400117397


 53%|█████▎    | 43604/82783 [33:14<37:05, 17.61it/s]

average loss: 0.007240676786750555


 53%|█████▎    | 43703/82783 [33:19<42:31, 15.32it/s]

average loss: 0.005837851669639349


 53%|█████▎    | 43804/82783 [33:24<36:35, 17.75it/s]

average loss: 0.022627048194408417


 53%|█████▎    | 43903/82783 [33:28<40:47, 15.89it/s]

average loss: 0.0329953134059906


 53%|█████▎    | 44005/82783 [33:33<37:00, 17.46it/s]

average loss: 0.023566966876387596


 53%|█████▎    | 44104/82783 [33:37<37:57, 16.98it/s]

average loss: 0.007328640203922987


 53%|█████▎    | 44203/82783 [33:42<40:46, 15.77it/s]

average loss: 0.013238275423645973


 54%|█████▎    | 44305/82783 [33:47<36:28, 17.58it/s]

average loss: 0.01101231575012207


 54%|█████▎    | 44404/82783 [33:51<36:24, 17.57it/s]

average loss: 0.009678322821855545


 54%|█████▍    | 44503/82783 [33:56<39:53, 15.99it/s]

average loss: 0.009926514700055122


 54%|█████▍    | 44605/82783 [34:01<36:39, 17.36it/s]

average loss: 0.006326674483716488


 54%|█████▍    | 44704/82783 [34:05<35:55, 17.67it/s]

average loss: 0.019278496503829956


 54%|█████▍    | 44803/82783 [34:10<39:30, 16.02it/s]

average loss: 0.007930343970656395


 54%|█████▍    | 44905/82783 [34:14<36:08, 17.47it/s]

average loss: 0.005322865676134825


 54%|█████▍    | 45004/82783 [34:19<35:48, 17.58it/s]

average loss: 0.011708086356520653


 54%|█████▍    | 45103/82783 [34:24<38:52, 16.16it/s]

average loss: 0.019686970859766006


 55%|█████▍    | 45205/82783 [34:28<36:16, 17.26it/s]

average loss: 0.008689730428159237


 55%|█████▍    | 45304/82783 [34:33<35:04, 17.81it/s]

average loss: 0.006138961762189865


 55%|█████▍    | 45403/82783 [34:37<39:04, 15.94it/s]

average loss: 0.02447531744837761


 55%|█████▍    | 45505/82783 [34:42<35:31, 17.49it/s]

average loss: 0.009056062437593937


 55%|█████▌    | 45603/82783 [34:47<40:05, 15.45it/s]

average loss: 0.015109732747077942


 55%|█████▌    | 45704/82783 [34:51<37:10, 16.63it/s]

average loss: 0.010731153190135956


 55%|█████▌    | 45803/82783 [34:56<38:49, 15.87it/s]

average loss: 0.020079245790839195


 55%|█████▌    | 45905/82783 [35:01<35:56, 17.10it/s]

average loss: 0.006054850295186043


 56%|█████▌    | 46004/82783 [35:05<34:16, 17.88it/s]

average loss: 0.004885422997176647


 56%|█████▌    | 46103/82783 [35:10<38:38, 15.82it/s]

average loss: 0.0157717727124691


 56%|█████▌    | 46205/82783 [35:15<35:15, 17.29it/s]

average loss: 0.016925644129514694


 56%|█████▌    | 46304/82783 [35:19<34:26, 17.65it/s]

average loss: 0.016857963055372238


 56%|█████▌    | 46403/82783 [35:24<38:39, 15.68it/s]

average loss: 0.0068592494353652


 56%|█████▌    | 46505/82783 [35:29<34:21, 17.60it/s]

average loss: 0.013032637536525726


 56%|█████▋    | 46604/82783 [35:33<34:00, 17.73it/s]

average loss: 0.01611010544002056


 56%|█████▋    | 46703/82783 [35:38<37:18, 16.12it/s]

average loss: 0.03310629725456238


 57%|█████▋    | 46805/82783 [35:43<33:02, 18.15it/s]

average loss: 0.00341587932780385


 57%|█████▋    | 46904/82783 [35:47<34:13, 17.48it/s]

average loss: 0.007678893860429525


 57%|█████▋    | 47003/82783 [35:52<38:09, 15.63it/s]

average loss: 0.024752378463745117


 57%|█████▋    | 47105/82783 [35:56<34:57, 17.01it/s]

average loss: 0.01714867725968361


 57%|█████▋    | 47204/82783 [36:01<33:19, 17.80it/s]

average loss: 0.022776726633310318


 57%|█████▋    | 47303/82783 [36:06<36:50, 16.05it/s]

average loss: 0.005987327545881271


 57%|█████▋    | 47405/82783 [36:10<35:00, 16.84it/s]

average loss: 0.00960596650838852


 57%|█████▋    | 47504/82783 [36:15<33:15, 17.68it/s]

average loss: 0.014535188674926758


 58%|█████▊    | 47603/82783 [36:19<37:12, 15.76it/s]

average loss: 0.02081310749053955


 58%|█████▊    | 47705/82783 [36:24<33:08, 17.64it/s]

average loss: 0.008660627529025078


 58%|█████▊    | 47804/82783 [36:29<32:38, 17.86it/s]

average loss: 0.014465946704149246


 58%|█████▊    | 47903/82783 [36:33<36:15, 16.03it/s]

average loss: 0.027440588921308517


 58%|█████▊    | 48005/82783 [36:38<32:55, 17.60it/s]

average loss: 0.0035502330865710974


 58%|█████▊    | 48103/82783 [36:43<36:27, 15.85it/s]

average loss: 0.016359765082597733


 58%|█████▊    | 48205/82783 [36:47<33:25, 17.24it/s]

average loss: 0.008472857996821404


 58%|█████▊    | 48304/82783 [36:52<33:05, 17.36it/s]

average loss: 0.01258432399481535


 58%|█████▊    | 48403/82783 [36:57<36:18, 15.78it/s]

average loss: 0.030075063928961754


 59%|█████▊    | 48505/82783 [37:01<31:40, 18.04it/s]

average loss: 0.003562053432688117


 59%|█████▊    | 48604/82783 [37:06<32:14, 17.67it/s]

average loss: 0.008091232739388943


 59%|█████▉    | 48703/82783 [37:10<36:49, 15.43it/s]

average loss: 0.017539113759994507


 59%|█████▉    | 48805/82783 [37:15<32:00, 17.69it/s]

average loss: 0.0174812413752079


 59%|█████▉    | 48904/82783 [37:20<31:44, 17.79it/s]

average loss: 0.003291552420705557


 59%|█████▉    | 49003/82783 [37:24<35:16, 15.96it/s]

average loss: 0.026140861213207245


 59%|█████▉    | 49105/82783 [37:29<31:33, 17.79it/s]

average loss: 0.017514947801828384


 59%|█████▉    | 49204/82783 [37:34<31:51, 17.57it/s]

average loss: 0.02173621580004692


 60%|█████▉    | 49303/82783 [37:38<35:37, 15.67it/s]

average loss: 0.00843015592545271


 60%|█████▉    | 49405/82783 [37:43<31:18, 17.77it/s]

average loss: 0.01773492433130741


 60%|█████▉    | 49504/82783 [37:47<31:58, 17.35it/s]

average loss: 0.007818687707185745


 60%|█████▉    | 49603/82783 [37:52<34:28, 16.04it/s]

average loss: 0.02802341803908348


 60%|██████    | 49705/82783 [37:57<32:44, 16.84it/s]

average loss: 0.03203988075256348


 60%|██████    | 49804/82783 [38:01<32:12, 17.06it/s]

average loss: 0.01897699199616909


 60%|██████    | 49903/82783 [38:06<33:38, 16.29it/s]

average loss: 0.023635467514395714


 60%|██████    | 50005/82783 [38:11<31:20, 17.43it/s]

average loss: 0.027380067855119705


 61%|██████    | 50104/82783 [38:15<31:06, 17.50it/s]

average loss: 0.025941312313079834


 61%|██████    | 50203/82783 [38:20<33:13, 16.34it/s]

average loss: 0.022089986130595207


 61%|██████    | 50305/82783 [38:24<29:54, 18.10it/s]

average loss: 0.014207085594534874


 61%|██████    | 50404/82783 [38:29<31:43, 17.01it/s]

average loss: 0.006924903951585293


 61%|██████    | 50503/82783 [38:34<33:25, 16.10it/s]

average loss: 0.007696148939430714


 61%|██████    | 50605/82783 [38:38<30:45, 17.44it/s]

average loss: 0.009267948567867279


 61%|██████    | 50704/82783 [38:43<30:09, 17.73it/s]

average loss: 0.006228956393897533


 61%|██████▏   | 50803/82783 [38:47<33:39, 15.84it/s]

average loss: 0.0069099063985049725


 61%|██████▏   | 50905/82783 [38:52<29:21, 18.10it/s]

average loss: 0.007456253282725811


 62%|██████▏   | 51004/82783 [38:57<30:17, 17.49it/s]

average loss: 0.007345890626311302


 62%|██████▏   | 51103/82783 [39:01<32:07, 16.43it/s]

average loss: 0.011464382521808147


 62%|██████▏   | 51205/82783 [39:06<30:03, 17.51it/s]

average loss: 0.007625700440257788


 62%|██████▏   | 51304/82783 [39:11<30:11, 17.38it/s]

average loss: 0.005319227930158377


 62%|██████▏   | 51403/82783 [39:15<32:50, 15.93it/s]

average loss: 0.033302806317806244


 62%|██████▏   | 51505/82783 [39:20<29:57, 17.40it/s]

average loss: 0.006926952861249447


 62%|██████▏   | 51604/82783 [39:24<29:01, 17.91it/s]

average loss: 0.0070049758069217205


 62%|██████▏   | 51703/82783 [39:29<32:35, 15.90it/s]

average loss: 0.01123923808336258


 63%|██████▎   | 51805/82783 [39:34<29:48, 17.32it/s]

average loss: 0.014347955584526062


 63%|██████▎   | 51904/82783 [39:38<29:01, 17.73it/s]

average loss: 0.019722647964954376


 63%|██████▎   | 52003/82783 [39:43<32:31, 15.77it/s]

average loss: 0.028208721429109573


 63%|██████▎   | 52103/82783 [39:48<30:09, 16.95it/s]

average loss: 0.007012469228357077


 63%|██████▎   | 52205/82783 [39:52<30:02, 16.96it/s]

average loss: 0.013743794523179531


 63%|██████▎   | 52304/82783 [39:57<28:31, 17.81it/s]

average loss: 0.021377963945269585


 63%|██████▎   | 52403/82783 [40:01<31:12, 16.23it/s]

average loss: 0.01778903603553772


 63%|██████▎   | 52505/82783 [40:06<29:07, 17.33it/s]

average loss: 0.014499031007289886


 64%|██████▎   | 52604/82783 [40:11<29:16, 17.18it/s]

average loss: 0.011204078793525696


 64%|██████▎   | 52703/82783 [40:15<31:43, 15.80it/s]

average loss: 0.01041010208427906


 64%|██████▍   | 52805/82783 [40:20<28:11, 17.73it/s]

average loss: 0.01824595406651497


 64%|██████▍   | 52904/82783 [40:25<29:00, 17.17it/s]

average loss: 0.015885833650827408


 64%|██████▍   | 53003/82783 [40:29<30:55, 16.05it/s]

average loss: 0.011358266696333885


 64%|██████▍   | 53105/82783 [40:34<28:08, 17.57it/s]

average loss: 0.005471613258123398


 64%|██████▍   | 53204/82783 [40:38<28:03, 17.57it/s]

average loss: 0.014186440035700798


 64%|██████▍   | 53303/82783 [40:43<30:24, 16.16it/s]

average loss: 0.008845662698149681


 65%|██████▍   | 53405/82783 [40:48<27:40, 17.69it/s]

average loss: 0.036851055920124054


 65%|██████▍   | 53504/82783 [40:52<27:11, 17.95it/s]

average loss: 0.023631218820810318


 65%|██████▍   | 53603/82783 [40:57<30:09, 16.13it/s]

average loss: 0.01876877248287201


 65%|██████▍   | 53705/82783 [41:02<27:38, 17.53it/s]

average loss: 0.018585383892059326


 65%|██████▍   | 53804/82783 [41:06<27:07, 17.81it/s]

average loss: 0.008976083248853683


 65%|██████▌   | 53903/82783 [41:11<29:36, 16.25it/s]

average loss: 0.008953729644417763


 65%|██████▌   | 54005/82783 [41:15<27:10, 17.65it/s]

average loss: 0.01234050840139389


 65%|██████▌   | 54104/82783 [41:20<27:35, 17.32it/s]

average loss: 0.005827001295983791


 65%|██████▌   | 54203/82783 [41:25<30:03, 15.85it/s]

average loss: 0.011423030868172646


 66%|██████▌   | 54305/82783 [41:29<26:57, 17.61it/s]

average loss: 0.010475854389369488


 66%|██████▌   | 54404/82783 [41:34<26:53, 17.59it/s]

average loss: 0.007903071120381355


 66%|██████▌   | 54503/82783 [41:38<29:03, 16.22it/s]

average loss: 0.023677479475736618


 66%|██████▌   | 54605/82783 [41:43<27:10, 17.28it/s]

average loss: 0.020535767078399658


 66%|██████▌   | 54704/82783 [41:48<26:33, 17.62it/s]

average loss: 0.009030432440340519


 66%|██████▌   | 54803/82783 [41:52<29:26, 15.84it/s]

average loss: 0.03513035178184509


 66%|██████▋   | 54905/82783 [41:57<27:34, 16.85it/s]

average loss: 0.012321792542934418


 66%|██████▋   | 55004/82783 [42:01<25:25, 18.21it/s]

average loss: 0.020295048132538795


 67%|██████▋   | 55103/82783 [42:06<28:28, 16.20it/s]

average loss: 0.0033566346392035484


 67%|██████▋   | 55205/82783 [42:11<26:12, 17.54it/s]

average loss: 0.0038333479315042496


 67%|██████▋   | 55304/82783 [42:15<26:33, 17.24it/s]

average loss: 0.005788800306618214


 67%|██████▋   | 55403/82783 [42:20<28:53, 15.79it/s]

average loss: 0.009310673922300339


 67%|██████▋   | 55505/82783 [42:25<25:24, 17.89it/s]

average loss: 0.01351176481693983


 67%|██████▋   | 55604/82783 [42:29<26:02, 17.39it/s]

average loss: 0.007608499377965927


 67%|██████▋   | 55705/82783 [42:34<27:36, 16.34it/s]

average loss: 0.007745531387627125


 67%|██████▋   | 55804/82783 [42:38<25:26, 17.67it/s]

average loss: 0.0049466039054095745


 68%|██████▊   | 55903/82783 [42:43<27:29, 16.29it/s]

average loss: 0.013795142993330956


 68%|██████▊   | 56005/82783 [42:48<25:44, 17.34it/s]

average loss: 0.005215439014136791


 68%|██████▊   | 56104/82783 [42:52<25:42, 17.29it/s]

average loss: 0.005200481042265892


 68%|██████▊   | 56203/82783 [42:57<27:26, 16.15it/s]

average loss: 0.019246947020292282


 68%|██████▊   | 56305/82783 [43:02<24:33, 17.96it/s]

average loss: 0.006744598504155874


 68%|██████▊   | 56404/82783 [43:06<25:08, 17.49it/s]

average loss: 0.005859720520675182


 68%|██████▊   | 56503/82783 [43:11<27:26, 15.96it/s]

average loss: 0.009944070130586624


 68%|██████▊   | 56605/82783 [43:15<25:09, 17.34it/s]

average loss: 0.00978977419435978


 68%|██████▊   | 56703/82783 [43:20<28:09, 15.44it/s]

average loss: 0.007898380979895592


 69%|██████▊   | 56805/82783 [43:25<24:36, 17.59it/s]

average loss: 0.006836016662418842


 69%|██████▊   | 56904/82783 [43:29<24:44, 17.43it/s]

average loss: 0.05308517813682556


 69%|██████▉   | 57003/82783 [43:34<26:56, 15.94it/s]

average loss: 0.012052072212100029


 69%|██████▉   | 57105/82783 [43:39<24:29, 17.48it/s]

average loss: 0.024209491908550262


 69%|██████▉   | 57204/82783 [43:43<23:43, 17.97it/s]

average loss: 0.011256581172347069


 69%|██████▉   | 57303/82783 [43:48<26:43, 15.89it/s]

average loss: 0.014232161454856396


 69%|██████▉   | 57405/82783 [43:53<24:07, 17.54it/s]

average loss: 0.02022199146449566


 69%|██████▉   | 57504/82783 [43:57<23:23, 18.01it/s]

average loss: 0.017490778118371964


 70%|██████▉   | 57603/82783 [44:02<26:31, 15.82it/s]

average loss: 0.0034901227336376905


 70%|██████▉   | 57705/82783 [44:07<25:15, 16.55it/s]

average loss: 0.014594295993447304


 70%|██████▉   | 57804/82783 [44:11<24:13, 17.18it/s]

average loss: 0.019146602600812912


 70%|██████▉   | 57903/82783 [44:16<25:46, 16.09it/s]

average loss: 0.003520750906318426


 70%|███████   | 58004/82783 [44:20<23:13, 17.79it/s]

average loss: 0.011172056198120117


 70%|███████   | 58103/82783 [44:25<25:46, 15.96it/s]

average loss: 0.0185233186930418


 70%|███████   | 58205/82783 [44:30<23:00, 17.81it/s]

average loss: 0.014171648770570755


 70%|███████   | 58304/82783 [44:34<23:35, 17.30it/s]

average loss: 0.009785406291484833


 71%|███████   | 58403/82783 [44:39<26:08, 15.55it/s]

average loss: 0.015679366886615753


 71%|███████   | 58505/82783 [44:44<23:06, 17.50it/s]

average loss: 0.007126647979021072


 71%|███████   | 58604/82783 [44:48<22:52, 17.62it/s]

average loss: 0.009244962595403194


 71%|███████   | 58703/82783 [44:53<25:05, 16.00it/s]

average loss: 0.012301558628678322


 71%|███████   | 58805/82783 [44:57<23:04, 17.32it/s]

average loss: 0.025092532858252525


 71%|███████   | 58904/82783 [45:02<22:20, 17.81it/s]

average loss: 0.013825558125972748


 71%|███████▏  | 59003/82783 [45:07<24:30, 16.17it/s]

average loss: 0.009403545409440994


 71%|███████▏  | 59105/82783 [45:11<22:39, 17.42it/s]

average loss: 0.00922374241054058


 72%|███████▏  | 59204/82783 [45:16<22:32, 17.43it/s]

average loss: 0.024621907621622086


 72%|███████▏  | 59303/82783 [45:20<24:32, 15.94it/s]

average loss: 0.004504359792917967


 72%|███████▏  | 59405/82783 [45:25<22:02, 17.68it/s]

average loss: 0.03416792303323746


 72%|███████▏  | 59504/82783 [45:30<22:14, 17.44it/s]

average loss: 0.004750202875584364


 72%|███████▏  | 59605/82783 [45:34<21:57, 17.59it/s]

average loss: 0.013659445568919182


 72%|███████▏  | 59704/82783 [45:39<22:25, 17.15it/s]

average loss: 0.024719832465052605


 72%|███████▏  | 59803/82783 [45:44<23:56, 16.00it/s]

average loss: 0.003130729775875807


 72%|███████▏  | 59905/82783 [45:48<21:27, 17.76it/s]

average loss: 0.028291817754507065


 72%|███████▏  | 60004/82783 [45:53<21:43, 17.48it/s]

average loss: 0.0030037211254239082


 73%|███████▎  | 60103/82783 [45:58<23:36, 16.02it/s]

average loss: 0.019068587571382523


 73%|███████▎  | 60205/82783 [46:02<21:42, 17.33it/s]

average loss: 0.01741548627614975


 73%|███████▎  | 60304/82783 [46:07<21:46, 17.20it/s]

average loss: 0.012349538505077362


 73%|███████▎  | 60403/82783 [46:12<24:25, 15.27it/s]

average loss: 0.007380572613328695


 73%|███████▎  | 60505/82783 [46:16<21:03, 17.63it/s]

average loss: 0.018424469977617264


 73%|███████▎  | 60604/82783 [46:21<20:39, 17.89it/s]

average loss: 0.023010119795799255


 73%|███████▎  | 60703/82783 [46:26<23:29, 15.67it/s]

average loss: 0.009359794668853283


 73%|███████▎  | 60805/82783 [46:30<21:30, 17.03it/s]

average loss: 0.019470637664198875


 74%|███████▎  | 60904/82783 [46:35<21:13, 17.18it/s]

average loss: 0.02650907263159752


 74%|███████▎  | 61003/82783 [46:39<23:13, 15.63it/s]

average loss: 0.003002288518473506


 74%|███████▍  | 61104/82783 [46:44<20:30, 17.61it/s]

average loss: 0.00833808071911335


 74%|███████▍  | 61203/82783 [46:49<22:24, 16.06it/s]

average loss: 0.011782861314713955


 74%|███████▍  | 61305/82783 [46:53<20:36, 17.36it/s]

average loss: 0.039283789694309235


 74%|███████▍  | 61404/82783 [46:58<20:15, 17.58it/s]

average loss: 0.0034319646656513214


 74%|███████▍  | 61503/82783 [47:03<22:11, 15.98it/s]

average loss: 0.014094952493906021


 74%|███████▍  | 61605/82783 [47:07<20:28, 17.24it/s]

average loss: 0.013217347674071789


 75%|███████▍  | 61704/82783 [47:12<20:43, 16.95it/s]

average loss: 0.013917389325797558


 75%|███████▍  | 61803/82783 [47:17<22:43, 15.39it/s]

average loss: 0.015331579372286797


 75%|███████▍  | 61905/82783 [47:21<20:35, 16.89it/s]

average loss: 0.005798558704555035


 75%|███████▍  | 62004/82783 [47:26<19:58, 17.33it/s]

average loss: 0.009339134208858013


 75%|███████▌  | 62103/82783 [47:31<22:28, 15.33it/s]

average loss: 0.02093491517007351


 75%|███████▌  | 62205/82783 [47:35<19:21, 17.72it/s]

average loss: 0.012176603078842163


 75%|███████▌  | 62304/82783 [47:40<19:31, 17.48it/s]

average loss: 0.014382466673851013


 75%|███████▌  | 62403/82783 [47:45<21:26, 15.85it/s]

average loss: 0.017660707235336304


 76%|███████▌  | 62505/82783 [47:49<19:27, 17.37it/s]

average loss: 0.01020500622689724


 76%|███████▌  | 62604/82783 [47:54<19:14, 17.47it/s]

average loss: 0.019372059032320976


 76%|███████▌  | 62703/82783 [47:59<21:19, 15.69it/s]

average loss: 0.021085090935230255


 76%|███████▌  | 62805/82783 [48:03<18:54, 17.60it/s]

average loss: 0.013263292610645294


 76%|███████▌  | 62904/82783 [48:08<19:13, 17.23it/s]

average loss: 0.019700290635228157


 76%|███████▌  | 63003/82783 [48:13<21:31, 15.32it/s]

average loss: 0.03213617950677872


 76%|███████▌  | 63105/82783 [48:17<18:49, 17.43it/s]

average loss: 0.024573182687163353


 76%|███████▋  | 63204/82783 [48:22<18:31, 17.62it/s]

average loss: 0.02032078616321087


 76%|███████▋  | 63303/82783 [48:27<20:24, 15.90it/s]

average loss: 0.013030989095568657


 77%|███████▋  | 63405/82783 [48:31<18:54, 17.09it/s]

average loss: 0.00952111929655075


 77%|███████▋  | 63504/82783 [48:36<18:22, 17.49it/s]

average loss: 0.00708008324727416


 77%|███████▋  | 63605/82783 [48:41<17:57, 17.80it/s]

average loss: 0.012764732353389263


 77%|███████▋  | 63704/82783 [48:45<18:17, 17.38it/s]

average loss: 0.014944418333470821


 77%|███████▋  | 63803/82783 [48:50<19:43, 16.04it/s]

average loss: 0.015122653916478157


 77%|███████▋  | 63905/82783 [48:55<18:16, 17.22it/s]

average loss: 0.0025570306461304426


 77%|███████▋  | 64004/82783 [48:59<18:06, 17.29it/s]

average loss: 0.005858229007571936


 77%|███████▋  | 64103/82783 [49:04<18:48, 16.55it/s]

average loss: 0.05839593708515167


 78%|███████▊  | 64205/82783 [49:09<17:46, 17.42it/s]

average loss: 0.018497899174690247


 78%|███████▊  | 64304/82783 [49:13<17:24, 17.68it/s]

average loss: 0.014910850673913956


 78%|███████▊  | 64403/82783 [49:18<19:08, 16.01it/s]

average loss: 0.0091353515163064


 78%|███████▊  | 64505/82783 [49:22<17:51, 17.05it/s]

average loss: 0.02706446498632431


 78%|███████▊  | 64604/82783 [49:27<17:01, 17.80it/s]

average loss: 0.004286559298634529


 78%|███████▊  | 64703/82783 [49:32<19:16, 15.63it/s]

average loss: 0.014961861073970795


 78%|███████▊  | 64804/82783 [49:36<17:06, 17.52it/s]

average loss: 0.0036076854448765516


 78%|███████▊  | 64903/82783 [49:41<18:46, 15.87it/s]

average loss: 0.020480390638113022


 79%|███████▊  | 65005/82783 [49:46<16:52, 17.56it/s]

average loss: 0.004687094129621983


 79%|███████▊  | 65104/82783 [49:50<16:52, 17.45it/s]

average loss: 0.003720274893566966


 79%|███████▉  | 65203/82783 [49:55<18:33, 15.80it/s]

average loss: 0.00449777115136385


 79%|███████▉  | 65305/82783 [50:00<16:30, 17.64it/s]

average loss: 0.009477101266384125


 79%|███████▉  | 65404/82783 [50:04<16:38, 17.40it/s]

average loss: 0.019532259553670883


 79%|███████▉  | 65503/82783 [50:09<18:16, 15.76it/s]

average loss: 0.020424334332346916


 79%|███████▉  | 65605/82783 [50:14<16:20, 17.52it/s]

average loss: 0.011202728375792503


 79%|███████▉  | 65704/82783 [50:18<16:37, 17.12it/s]

average loss: 0.01355819683521986


 79%|███████▉  | 65803/82783 [50:23<17:57, 15.77it/s]

average loss: 0.012647032737731934


 80%|███████▉  | 65905/82783 [50:28<16:45, 16.78it/s]

average loss: 0.012547033838927746


 80%|███████▉  | 66004/82783 [50:32<15:36, 17.91it/s]

average loss: 0.005640159826725721


 80%|███████▉  | 66103/82783 [50:37<17:15, 16.11it/s]

average loss: 0.006074447184801102


 80%|███████▉  | 66205/82783 [50:42<15:45, 17.53it/s]

average loss: 0.03602251410484314


 80%|████████  | 66304/82783 [50:46<16:05, 17.07it/s]

average loss: 0.010546203702688217


 80%|████████  | 66403/82783 [50:51<17:03, 16.01it/s]

average loss: 0.017524031922221184


 80%|████████  | 66504/82783 [50:55<15:10, 17.87it/s]

average loss: 0.005934457294642925


 80%|████████  | 66603/82783 [51:00<17:15, 15.62it/s]

average loss: 0.010030128993093967


 81%|████████  | 66705/82783 [51:05<15:07, 17.72it/s]

average loss: 0.005724902264773846


 81%|████████  | 66804/82783 [51:09<14:55, 17.84it/s]

average loss: 0.009233562275767326


 81%|████████  | 66903/82783 [51:14<16:40, 15.87it/s]

average loss: 0.008408747613430023


 81%|████████  | 67005/82783 [51:19<15:03, 17.46it/s]

average loss: 0.008056028746068478


 81%|████████  | 67103/82783 [51:23<16:03, 16.28it/s]

average loss: 0.009758219122886658


 81%|████████  | 67205/82783 [51:28<15:01, 17.29it/s]

average loss: 0.01573767140507698


 81%|████████▏ | 67304/82783 [51:33<14:39, 17.61it/s]

average loss: 0.009981931187212467


 81%|████████▏ | 67403/82783 [51:37<15:46, 16.25it/s]

average loss: 0.009021284990012646


 82%|████████▏ | 67505/82783 [51:42<14:21, 17.73it/s]

average loss: 0.011743379756808281


 82%|████████▏ | 67604/82783 [51:46<14:26, 17.52it/s]

average loss: 0.0078108226880431175


 82%|████████▏ | 67706/82783 [51:51<13:58, 17.97it/s]

average loss: 0.006034869700670242


 82%|████████▏ | 67805/82783 [51:56<14:00, 17.83it/s]

average loss: 0.006494456902146339


 82%|████████▏ | 67904/82783 [52:00<14:05, 17.59it/s]

average loss: 0.014268198981881142


 82%|████████▏ | 68006/82783 [52:05<13:36, 18.11it/s]

average loss: 0.0061448416672647


 82%|████████▏ | 68105/82783 [52:09<13:51, 17.66it/s]

average loss: 0.003924015909433365


 82%|████████▏ | 68204/82783 [52:14<13:51, 17.54it/s]

average loss: 0.0031154793687164783


 83%|████████▎ | 68303/82783 [52:19<14:39, 16.47it/s]

average loss: 0.01229636650532484


 83%|████████▎ | 68405/82783 [52:23<13:37, 17.59it/s]

average loss: 0.01382299792021513


 83%|████████▎ | 68504/82783 [52:28<13:28, 17.67it/s]

average loss: 0.011244113557040691


 83%|████████▎ | 68603/82783 [52:32<14:42, 16.08it/s]

average loss: 0.019078455865383148


 83%|████████▎ | 68705/82783 [52:37<13:26, 17.46it/s]

average loss: 0.016628744080662727


 83%|████████▎ | 68804/82783 [52:42<12:54, 18.06it/s]

average loss: 0.01003392692655325


 83%|████████▎ | 68903/82783 [52:46<14:19, 16.15it/s]

average loss: 0.007451247423887253


 83%|████████▎ | 69005/82783 [52:51<13:13, 17.36it/s]

average loss: 0.02206946536898613


 83%|████████▎ | 69104/82783 [52:56<13:11, 17.29it/s]

average loss: 0.016815152019262314


 84%|████████▎ | 69203/82783 [53:00<14:12, 15.93it/s]

average loss: 0.011713646352291107


 84%|████████▎ | 69304/82783 [53:05<13:31, 16.62it/s]

average loss: 0.012054041028022766


 84%|████████▍ | 69405/82783 [53:10<12:49, 17.39it/s]

average loss: 0.004951895214617252


 84%|████████▍ | 69504/82783 [53:14<12:33, 17.63it/s]

average loss: 0.006128688808530569


 84%|████████▍ | 69603/82783 [53:19<13:59, 15.70it/s]

average loss: 0.04430879279971123


 84%|████████▍ | 69705/82783 [53:23<12:31, 17.40it/s]

average loss: 0.008735204115509987


 84%|████████▍ | 69804/82783 [53:28<12:05, 17.90it/s]

average loss: 0.017566516995429993


 84%|████████▍ | 69903/82783 [53:33<13:29, 15.92it/s]

average loss: 0.012444357387721539


 85%|████████▍ | 70005/82783 [53:37<12:16, 17.34it/s]

average loss: 0.005681569688022137


 85%|████████▍ | 70104/82783 [53:42<12:21, 17.09it/s]

average loss: 0.006715143099427223


 85%|████████▍ | 70203/82783 [53:47<12:31, 16.75it/s]

average loss: 0.005444146227091551


 85%|████████▍ | 70305/82783 [53:51<11:52, 17.52it/s]

average loss: 0.02387886308133602


 85%|████████▌ | 70404/82783 [53:56<11:57, 17.25it/s]

average loss: 0.014457136392593384


 85%|████████▌ | 70503/82783 [54:00<12:42, 16.11it/s]

average loss: 0.01788720302283764


 85%|████████▌ | 70605/82783 [54:05<11:30, 17.62it/s]

average loss: 0.010160518810153008


 85%|████████▌ | 70704/82783 [54:10<11:32, 17.44it/s]

average loss: 0.011824768036603928


 86%|████████▌ | 70803/82783 [54:14<12:33, 15.89it/s]

average loss: 0.006223518401384354


 86%|████████▌ | 70905/82783 [54:19<11:10, 17.70it/s]

average loss: 0.007718070410192013


 86%|████████▌ | 71004/82783 [54:24<10:59, 17.86it/s]

average loss: 0.005840933881700039


 86%|████████▌ | 71103/82783 [54:28<11:46, 16.54it/s]

average loss: 0.024720139801502228


 86%|████████▌ | 71205/82783 [54:33<11:03, 17.46it/s]

average loss: 0.01048604678362608


 86%|████████▌ | 71304/82783 [54:37<10:44, 17.82it/s]

average loss: 0.010676047764718533


 86%|████████▋ | 71403/82783 [54:42<11:49, 16.03it/s]

average loss: 0.005750506184995174


 86%|████████▋ | 71505/82783 [54:47<10:40, 17.61it/s]

average loss: 0.004231026396155357


 86%|████████▋ | 71604/82783 [54:51<10:33, 17.66it/s]

average loss: 0.012654903344810009


 87%|████████▋ | 71703/82783 [54:56<11:57, 15.44it/s]

average loss: 0.008124817162752151


 87%|████████▋ | 71804/82783 [55:01<10:30, 17.42it/s]

average loss: 0.008371030911803246


 87%|████████▋ | 71905/82783 [55:05<10:31, 17.22it/s]

average loss: 0.010504837147891521


 87%|████████▋ | 72004/82783 [55:10<10:35, 16.96it/s]

average loss: 0.031905580312013626


 87%|████████▋ | 72103/82783 [55:14<11:19, 15.71it/s]

average loss: 0.01199404802173376


 87%|████████▋ | 72205/82783 [55:19<09:57, 17.70it/s]

average loss: 0.0034086350351572037


 87%|████████▋ | 72304/82783 [55:24<09:53, 17.66it/s]

average loss: 0.009297734126448631


 87%|████████▋ | 72403/82783 [55:28<10:50, 15.97it/s]

average loss: 0.024633577093482018


 88%|████████▊ | 72505/82783 [55:33<10:22, 16.50it/s]

average loss: 0.01091785542666912


 88%|████████▊ | 72604/82783 [55:38<09:47, 17.33it/s]

average loss: 0.005783191882073879


 88%|████████▊ | 72703/82783 [55:42<10:24, 16.13it/s]

average loss: 0.0027283686213195324


 88%|████████▊ | 72805/82783 [55:47<09:26, 17.61it/s]

average loss: 0.01214144192636013


 88%|████████▊ | 72904/82783 [55:52<09:36, 17.14it/s]

average loss: 0.010931700468063354


 88%|████████▊ | 73005/82783 [55:56<09:14, 17.64it/s]

average loss: 0.01093592494726181


 88%|████████▊ | 73104/82783 [56:01<09:11, 17.56it/s]

average loss: 0.0202484093606472


 88%|████████▊ | 73203/82783 [56:05<10:10, 15.70it/s]

average loss: 0.060590438544750214


 89%|████████▊ | 73305/82783 [56:10<09:06, 17.36it/s]

average loss: 0.016326047480106354


 89%|████████▊ | 73404/82783 [56:15<09:01, 17.32it/s]

average loss: 0.009628361091017723


 89%|████████▉ | 73503/82783 [56:19<09:43, 15.91it/s]

average loss: 0.009828958660364151


 89%|████████▉ | 73605/82783 [56:24<09:00, 16.98it/s]

average loss: 0.028122782707214355


 89%|████████▉ | 73704/82783 [56:29<08:33, 17.68it/s]

average loss: 0.009534569457173347


 89%|████████▉ | 73803/82783 [56:33<09:11, 16.29it/s]

average loss: 0.013542654924094677


 89%|████████▉ | 73905/82783 [56:38<08:12, 18.04it/s]

average loss: 0.011522002518177032


 89%|████████▉ | 74004/82783 [56:43<08:22, 17.48it/s]

average loss: 0.006337480153888464


 90%|████████▉ | 74103/82783 [56:47<08:57, 16.16it/s]

average loss: 0.004914736375212669


 90%|████████▉ | 74204/82783 [56:52<08:06, 17.62it/s]

average loss: 0.013099661096930504


 90%|████████▉ | 74303/82783 [56:56<08:54, 15.87it/s]

average loss: 0.011064055375754833


 90%|████████▉ | 74405/82783 [57:01<07:59, 17.46it/s]

average loss: 0.006975093856453896


 90%|████████▉ | 74504/82783 [57:06<07:54, 17.43it/s]

average loss: 0.004332421813160181


 90%|█████████ | 74603/82783 [57:10<08:32, 15.97it/s]

average loss: 0.003996390849351883


 90%|█████████ | 74705/82783 [57:15<07:37, 17.65it/s]

average loss: 0.007988998666405678


 90%|█████████ | 74804/82783 [57:20<07:32, 17.64it/s]

average loss: 0.009000757709145546


 90%|█████████ | 74903/82783 [57:24<08:11, 16.02it/s]

average loss: 0.009224026463925838


 91%|█████████ | 75005/82783 [57:29<07:42, 16.80it/s]

average loss: 0.008983176201581955


 91%|█████████ | 75104/82783 [57:34<07:28, 17.11it/s]

average loss: 0.011322191916406155


 91%|█████████ | 75203/82783 [57:38<08:01, 15.74it/s]

average loss: 0.023843882605433464


 91%|█████████ | 75305/82783 [57:43<07:01, 17.74it/s]

average loss: 0.004793920554220676


 91%|█████████ | 75404/82783 [57:48<07:12, 17.05it/s]

average loss: 0.006303762551397085


 91%|█████████ | 75503/82783 [57:52<07:37, 15.90it/s]

average loss: 0.015256493352353573


 91%|█████████▏| 75605/82783 [57:57<06:47, 17.63it/s]

average loss: 0.015557384118437767


 91%|█████████▏| 75704/82783 [58:02<06:41, 17.63it/s]

average loss: 0.004828562494367361


 92%|█████████▏| 75803/82783 [58:06<07:14, 16.08it/s]

average loss: 0.008273116312921047


 92%|█████████▏| 75904/82783 [58:11<06:29, 17.68it/s]

average loss: 0.006599606014788151


 92%|█████████▏| 76003/82783 [58:15<06:59, 16.15it/s]

average loss: 0.008977113291621208


 92%|█████████▏| 76105/82783 [58:20<06:22, 17.45it/s]

average loss: 0.008575123734772205


 92%|█████████▏| 76204/82783 [58:25<06:13, 17.62it/s]

average loss: 0.0037824760656803846


 92%|█████████▏| 76303/82783 [58:29<06:50, 15.78it/s]

average loss: 0.011151386424899101


 92%|█████████▏| 76405/82783 [58:34<05:56, 17.91it/s]

average loss: 0.004508316982537508


 92%|█████████▏| 76504/82783 [58:38<05:56, 17.61it/s]

average loss: 0.022600475698709488


 93%|█████████▎| 76603/82783 [58:43<06:28, 15.92it/s]

average loss: 0.007285742089152336


 93%|█████████▎| 76704/82783 [58:48<05:48, 17.44it/s]

average loss: 0.006484610494226217


 93%|█████████▎| 76803/82783 [58:52<06:13, 16.02it/s]

average loss: 0.01179767306894064


 93%|█████████▎| 76905/82783 [58:57<05:38, 17.39it/s]

average loss: 0.007048151921480894


 93%|█████████▎| 77004/82783 [59:02<05:29, 17.56it/s]

average loss: 0.01119375228881836


 93%|█████████▎| 77103/82783 [59:06<05:58, 15.85it/s]

average loss: 0.0022298144176602364


 93%|█████████▎| 77204/82783 [59:11<05:23, 17.22it/s]

average loss: 0.023249343037605286


 93%|█████████▎| 77303/82783 [59:16<06:08, 14.86it/s]

average loss: 0.012081722728908062


 94%|█████████▎| 77405/82783 [59:20<05:02, 17.77it/s]

average loss: 0.024612456560134888


 94%|█████████▎| 77504/82783 [59:25<05:04, 17.34it/s]

average loss: 0.012494903057813644


 94%|█████████▎| 77606/82783 [59:30<04:51, 17.75it/s]

average loss: 0.009559571743011475


 94%|█████████▍| 77705/82783 [59:34<04:57, 17.05it/s]

average loss: 0.015571322292089462


 94%|█████████▍| 77804/82783 [59:39<04:38, 17.87it/s]

average loss: 0.010882529430091381


 94%|█████████▍| 77903/82783 [59:43<05:05, 15.97it/s]

average loss: 0.01202276349067688


 94%|█████████▍| 78005/82783 [59:48<04:27, 17.84it/s]

average loss: 0.004074614495038986


 94%|█████████▍| 78104/82783 [59:53<04:28, 17.41it/s]

average loss: 0.022105291485786438


 94%|█████████▍| 78203/82783 [59:57<04:39, 16.39it/s]

average loss: 0.0038197836838662624


 95%|█████████▍| 78305/82783 [1:00:02<04:14, 17.62it/s]

average loss: 0.005166263319551945


 95%|█████████▍| 78404/82783 [1:00:06<04:15, 17.17it/s]

average loss: 0.013035736978054047


 95%|█████████▍| 78503/82783 [1:00:11<04:29, 15.89it/s]

average loss: 0.004308605566620827


 95%|█████████▍| 78605/82783 [1:00:16<03:59, 17.42it/s]

average loss: 0.030488386750221252


 95%|█████████▌| 78704/82783 [1:00:20<03:55, 17.29it/s]

average loss: 0.012207446619868279


 95%|█████████▌| 78803/82783 [1:00:25<04:11, 15.81it/s]

average loss: 0.00829758308827877


 95%|█████████▌| 78904/82783 [1:00:30<03:41, 17.55it/s]

average loss: 0.009392750449478626


 95%|█████████▌| 79005/82783 [1:00:34<03:51, 16.33it/s]

average loss: 0.018075089901685715


 96%|█████████▌| 79104/82783 [1:00:39<03:28, 17.63it/s]

average loss: 0.012471873313188553


 96%|█████████▌| 79203/82783 [1:00:44<03:46, 15.79it/s]

average loss: 0.004536286927759647


 96%|█████████▌| 79305/82783 [1:00:48<03:20, 17.33it/s]

average loss: 0.014065362513065338


 96%|█████████▌| 79404/82783 [1:00:53<03:11, 17.61it/s]

average loss: 0.010832466185092926


 96%|█████████▌| 79503/82783 [1:00:57<03:19, 16.42it/s]

average loss: 0.018945302814245224


 96%|█████████▌| 79605/82783 [1:01:02<03:12, 16.52it/s]

average loss: 0.0034843634348362684


 96%|█████████▋| 79704/82783 [1:01:07<02:59, 17.11it/s]

average loss: 0.01844196766614914


 96%|█████████▋| 79803/82783 [1:01:11<03:09, 15.75it/s]

average loss: 0.0073233433067798615


 97%|█████████▋| 79905/82783 [1:01:16<02:45, 17.44it/s]

average loss: 0.01760394126176834


 97%|█████████▋| 80004/82783 [1:01:21<02:39, 17.38it/s]

average loss: 0.010335841216146946


 97%|█████████▋| 80103/82783 [1:01:25<02:45, 16.24it/s]

average loss: 0.01302373968064785


 97%|█████████▋| 80205/82783 [1:01:30<02:29, 17.26it/s]

average loss: 0.013917459174990654


 97%|█████████▋| 80304/82783 [1:01:35<02:19, 17.71it/s]

average loss: 0.015162105672061443


 97%|█████████▋| 80403/82783 [1:01:39<02:28, 15.98it/s]

average loss: 0.008631523698568344


 97%|█████████▋| 80505/82783 [1:01:44<02:11, 17.38it/s]

average loss: 0.00972506869584322


 97%|█████████▋| 80604/82783 [1:01:48<02:04, 17.50it/s]

average loss: 0.013522882014513016


 97%|█████████▋| 80705/82783 [1:01:53<02:00, 17.31it/s]

average loss: 0.006240989547222853


 98%|█████████▊| 80804/82783 [1:01:58<01:52, 17.62it/s]

average loss: 0.007404200732707977


 98%|█████████▊| 80903/82783 [1:02:02<01:57, 16.02it/s]

average loss: 0.022687753662467003


 98%|█████████▊| 81004/82783 [1:02:07<01:44, 16.97it/s]

average loss: 0.002402637153863907


 98%|█████████▊| 81103/82783 [1:02:12<01:44, 16.12it/s]

average loss: 0.014116279780864716


 98%|█████████▊| 81205/82783 [1:02:16<01:31, 17.34it/s]

average loss: 0.010409928858280182


 98%|█████████▊| 81304/82783 [1:02:21<01:25, 17.36it/s]

average loss: 0.010271389037370682


 98%|█████████▊| 81404/82783 [1:02:25<01:21, 16.96it/s]

average loss: 0.0024975610431283712


 98%|█████████▊| 81503/82783 [1:02:30<01:22, 15.50it/s]

average loss: 0.004275467246770859


 99%|█████████▊| 81605/82783 [1:02:35<01:09, 16.98it/s]

average loss: 0.02025371417403221


 99%|█████████▊| 81704/82783 [1:02:39<01:02, 17.28it/s]

average loss: 0.012014586478471756


 99%|█████████▉| 81803/82783 [1:02:44<01:01, 15.98it/s]

average loss: 0.018329758197069168


 99%|█████████▉| 81904/82783 [1:02:49<00:52, 16.78it/s]

average loss: 0.011820683255791664


 99%|█████████▉| 82003/82783 [1:02:53<00:48, 15.96it/s]

average loss: 0.028638485819101334


 99%|█████████▉| 82105/82783 [1:02:58<00:38, 17.50it/s]

average loss: 0.008149327710270882


 99%|█████████▉| 82204/82783 [1:03:03<00:33, 17.44it/s]

average loss: 0.011515067890286446


 99%|█████████▉| 82303/82783 [1:03:07<00:30, 15.86it/s]

average loss: 0.006425307597965002


100%|█████████▉| 82405/82783 [1:03:12<00:21, 17.66it/s]

average loss: 0.006513225845992565


100%|█████████▉| 82504/82783 [1:03:17<00:15, 17.81it/s]

average loss: 0.010052001103758812


100%|█████████▉| 82603/82783 [1:03:21<00:11, 16.14it/s]

average loss: 0.002625333145260811


100%|█████████▉| 82705/82783 [1:03:26<00:04, 17.35it/s]

average loss: 0.006878917571157217


100%|██████████| 82783/82783 [1:03:29<00:00, 21.73it/s]


<font size="5">Create DALLE Model</font>

In [19]:
tokenizer = SimpleTokenizer()

dalle = DALLE(
    dim = 1024,
    vae = vae,                                 # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens = tokenizer.vocab_size,    # vocab size for text
    text_seq_len = 256,                        # text sequence length
    depth = 1,                                 # should aim to be 64
    heads = 16,                                # attention heads
    dim_head = 64,                             # attention head dimension
    attn_dropout = 0.1,                        # attention dropout
    ff_dropout = 0.1                           # feedforward dropout
).to(device)

if os.path.exists(dalle_save_path):
    dalle.load_state_dict(torch.load(dalle_save_path))

RuntimeError: Error(s) in loading state_dict for DALLE:
	Missing key(s) in state_dict: "vae.codebook.weight", "vae.encoder.0.0.weight", "vae.encoder.0.0.bias", "vae.encoder.1.0.weight", "vae.encoder.1.0.bias", "vae.encoder.2.0.weight", "vae.encoder.2.0.bias", "vae.encoder.3.net.0.weight", "vae.encoder.3.net.0.bias", "vae.encoder.3.net.2.weight", "vae.encoder.3.net.2.bias", "vae.encoder.3.net.4.weight", "vae.encoder.3.net.4.bias", "vae.encoder.4.weight", "vae.encoder.4.bias", "vae.decoder.0.weight", "vae.decoder.0.bias", "vae.decoder.1.net.0.weight", "vae.decoder.1.net.0.bias", "vae.decoder.1.net.2.weight", "vae.decoder.1.net.2.bias", "vae.decoder.1.net.4.weight", "vae.decoder.1.net.4.bias", "vae.decoder.2.0.weight", "vae.decoder.2.0.bias", "vae.decoder.3.0.weight", "vae.decoder.3.0.bias", "vae.decoder.4.0.weight", "vae.decoder.4.0.bias", "vae.decoder.5.weight", "vae.decoder.5.bias". 
	Unexpected key(s) in state_dict: "vae.enc.blocks.input.w", "vae.enc.blocks.input.b", "vae.enc.blocks.group_1.block_1.res_path.conv_1.w", "vae.enc.blocks.group_1.block_1.res_path.conv_1.b", "vae.enc.blocks.group_1.block_1.res_path.conv_2.w", "vae.enc.blocks.group_1.block_1.res_path.conv_2.b", "vae.enc.blocks.group_1.block_1.res_path.conv_3.w", "vae.enc.blocks.group_1.block_1.res_path.conv_3.b", "vae.enc.blocks.group_1.block_1.res_path.conv_4.w", "vae.enc.blocks.group_1.block_1.res_path.conv_4.b", "vae.enc.blocks.group_1.block_2.res_path.conv_1.w", "vae.enc.blocks.group_1.block_2.res_path.conv_1.b", "vae.enc.blocks.group_1.block_2.res_path.conv_2.w", "vae.enc.blocks.group_1.block_2.res_path.conv_2.b", "vae.enc.blocks.group_1.block_2.res_path.conv_3.w", "vae.enc.blocks.group_1.block_2.res_path.conv_3.b", "vae.enc.blocks.group_1.block_2.res_path.conv_4.w", "vae.enc.blocks.group_1.block_2.res_path.conv_4.b", "vae.enc.blocks.group_2.block_1.id_path.w", "vae.enc.blocks.group_2.block_1.id_path.b", "vae.enc.blocks.group_2.block_1.res_path.conv_1.w", "vae.enc.blocks.group_2.block_1.res_path.conv_1.b", "vae.enc.blocks.group_2.block_1.res_path.conv_2.w", "vae.enc.blocks.group_2.block_1.res_path.conv_2.b", "vae.enc.blocks.group_2.block_1.res_path.conv_3.w", "vae.enc.blocks.group_2.block_1.res_path.conv_3.b", "vae.enc.blocks.group_2.block_1.res_path.conv_4.w", "vae.enc.blocks.group_2.block_1.res_path.conv_4.b", "vae.enc.blocks.group_2.block_2.res_path.conv_1.w", "vae.enc.blocks.group_2.block_2.res_path.conv_1.b", "vae.enc.blocks.group_2.block_2.res_path.conv_2.w", "vae.enc.blocks.group_2.block_2.res_path.conv_2.b", "vae.enc.blocks.group_2.block_2.res_path.conv_3.w", "vae.enc.blocks.group_2.block_2.res_path.conv_3.b", "vae.enc.blocks.group_2.block_2.res_path.conv_4.w", "vae.enc.blocks.group_2.block_2.res_path.conv_4.b", "vae.enc.blocks.group_3.block_1.id_path.w", "vae.enc.blocks.group_3.block_1.id_path.b", "vae.enc.blocks.group_3.block_1.res_path.conv_1.w", "vae.enc.blocks.group_3.block_1.res_path.conv_1.b", "vae.enc.blocks.group_3.block_1.res_path.conv_2.w", "vae.enc.blocks.group_3.block_1.res_path.conv_2.b", "vae.enc.blocks.group_3.block_1.res_path.conv_3.w", "vae.enc.blocks.group_3.block_1.res_path.conv_3.b", "vae.enc.blocks.group_3.block_1.res_path.conv_4.w", "vae.enc.blocks.group_3.block_1.res_path.conv_4.b", "vae.enc.blocks.group_3.block_2.res_path.conv_1.w", "vae.enc.blocks.group_3.block_2.res_path.conv_1.b", "vae.enc.blocks.group_3.block_2.res_path.conv_2.w", "vae.enc.blocks.group_3.block_2.res_path.conv_2.b", "vae.enc.blocks.group_3.block_2.res_path.conv_3.w", "vae.enc.blocks.group_3.block_2.res_path.conv_3.b", "vae.enc.blocks.group_3.block_2.res_path.conv_4.w", "vae.enc.blocks.group_3.block_2.res_path.conv_4.b", "vae.enc.blocks.group_4.block_1.id_path.w", "vae.enc.blocks.group_4.block_1.id_path.b", "vae.enc.blocks.group_4.block_1.res_path.conv_1.w", "vae.enc.blocks.group_4.block_1.res_path.conv_1.b", "vae.enc.blocks.group_4.block_1.res_path.conv_2.w", "vae.enc.blocks.group_4.block_1.res_path.conv_2.b", "vae.enc.blocks.group_4.block_1.res_path.conv_3.w", "vae.enc.blocks.group_4.block_1.res_path.conv_3.b", "vae.enc.blocks.group_4.block_1.res_path.conv_4.w", "vae.enc.blocks.group_4.block_1.res_path.conv_4.b", "vae.enc.blocks.group_4.block_2.res_path.conv_1.w", "vae.enc.blocks.group_4.block_2.res_path.conv_1.b", "vae.enc.blocks.group_4.block_2.res_path.conv_2.w", "vae.enc.blocks.group_4.block_2.res_path.conv_2.b", "vae.enc.blocks.group_4.block_2.res_path.conv_3.w", "vae.enc.blocks.group_4.block_2.res_path.conv_3.b", "vae.enc.blocks.group_4.block_2.res_path.conv_4.w", "vae.enc.blocks.group_4.block_2.res_path.conv_4.b", "vae.enc.blocks.output.conv.w", "vae.enc.blocks.output.conv.b", "vae.dec.blocks.input.w", "vae.dec.blocks.input.b", "vae.dec.blocks.group_1.block_1.id_path.w", "vae.dec.blocks.group_1.block_1.id_path.b", "vae.dec.blocks.group_1.block_1.res_path.conv_1.w", "vae.dec.blocks.group_1.block_1.res_path.conv_1.b", "vae.dec.blocks.group_1.block_1.res_path.conv_2.w", "vae.dec.blocks.group_1.block_1.res_path.conv_2.b", "vae.dec.blocks.group_1.block_1.res_path.conv_3.w", "vae.dec.blocks.group_1.block_1.res_path.conv_3.b", "vae.dec.blocks.group_1.block_1.res_path.conv_4.w", "vae.dec.blocks.group_1.block_1.res_path.conv_4.b", "vae.dec.blocks.group_1.block_2.res_path.conv_1.w", "vae.dec.blocks.group_1.block_2.res_path.conv_1.b", "vae.dec.blocks.group_1.block_2.res_path.conv_2.w", "vae.dec.blocks.group_1.block_2.res_path.conv_2.b", "vae.dec.blocks.group_1.block_2.res_path.conv_3.w", "vae.dec.blocks.group_1.block_2.res_path.conv_3.b", "vae.dec.blocks.group_1.block_2.res_path.conv_4.w", "vae.dec.blocks.group_1.block_2.res_path.conv_4.b", "vae.dec.blocks.group_2.block_1.id_path.w", "vae.dec.blocks.group_2.block_1.id_path.b", "vae.dec.blocks.group_2.block_1.res_path.conv_1.w", "vae.dec.blocks.group_2.block_1.res_path.conv_1.b", "vae.dec.blocks.group_2.block_1.res_path.conv_2.w", "vae.dec.blocks.group_2.block_1.res_path.conv_2.b", "vae.dec.blocks.group_2.block_1.res_path.conv_3.w", "vae.dec.blocks.group_2.block_1.res_path.conv_3.b", "vae.dec.blocks.group_2.block_1.res_path.conv_4.w", "vae.dec.blocks.group_2.block_1.res_path.conv_4.b", "vae.dec.blocks.group_2.block_2.res_path.conv_1.w", "vae.dec.blocks.group_2.block_2.res_path.conv_1.b", "vae.dec.blocks.group_2.block_2.res_path.conv_2.w", "vae.dec.blocks.group_2.block_2.res_path.conv_2.b", "vae.dec.blocks.group_2.block_2.res_path.conv_3.w", "vae.dec.blocks.group_2.block_2.res_path.conv_3.b", "vae.dec.blocks.group_2.block_2.res_path.conv_4.w", "vae.dec.blocks.group_2.block_2.res_path.conv_4.b", "vae.dec.blocks.group_3.block_1.id_path.w", "vae.dec.blocks.group_3.block_1.id_path.b", "vae.dec.blocks.group_3.block_1.res_path.conv_1.w", "vae.dec.blocks.group_3.block_1.res_path.conv_1.b", "vae.dec.blocks.group_3.block_1.res_path.conv_2.w", "vae.dec.blocks.group_3.block_1.res_path.conv_2.b", "vae.dec.blocks.group_3.block_1.res_path.conv_3.w", "vae.dec.blocks.group_3.block_1.res_path.conv_3.b", "vae.dec.blocks.group_3.block_1.res_path.conv_4.w", "vae.dec.blocks.group_3.block_1.res_path.conv_4.b", "vae.dec.blocks.group_3.block_2.res_path.conv_1.w", "vae.dec.blocks.group_3.block_2.res_path.conv_1.b", "vae.dec.blocks.group_3.block_2.res_path.conv_2.w", "vae.dec.blocks.group_3.block_2.res_path.conv_2.b", "vae.dec.blocks.group_3.block_2.res_path.conv_3.w", "vae.dec.blocks.group_3.block_2.res_path.conv_3.b", "vae.dec.blocks.group_3.block_2.res_path.conv_4.w", "vae.dec.blocks.group_3.block_2.res_path.conv_4.b", "vae.dec.blocks.group_4.block_1.id_path.w", "vae.dec.blocks.group_4.block_1.id_path.b", "vae.dec.blocks.group_4.block_1.res_path.conv_1.w", "vae.dec.blocks.group_4.block_1.res_path.conv_1.b", "vae.dec.blocks.group_4.block_1.res_path.conv_2.w", "vae.dec.blocks.group_4.block_1.res_path.conv_2.b", "vae.dec.blocks.group_4.block_1.res_path.conv_3.w", "vae.dec.blocks.group_4.block_1.res_path.conv_3.b", "vae.dec.blocks.group_4.block_1.res_path.conv_4.w", "vae.dec.blocks.group_4.block_1.res_path.conv_4.b", "vae.dec.blocks.group_4.block_2.res_path.conv_1.w", "vae.dec.blocks.group_4.block_2.res_path.conv_1.b", "vae.dec.blocks.group_4.block_2.res_path.conv_2.w", "vae.dec.blocks.group_4.block_2.res_path.conv_2.b", "vae.dec.blocks.group_4.block_2.res_path.conv_3.w", "vae.dec.blocks.group_4.block_2.res_path.conv_3.b", "vae.dec.blocks.group_4.block_2.res_path.conv_4.w", "vae.dec.blocks.group_4.block_2.res_path.conv_4.b", "vae.dec.blocks.output.conv.w", "vae.dec.blocks.output.conv.b". 

<font size="5">Train DALLE Model</font>

In [None]:
train_size = len(train_data)
idx_list = range(0, train_size, batch_size)

opt = Adam(
    get_trainable_params(dalle),
    lr = 3e-4,
    # weight_decay=0.01,
    # betas = (0.9, 0.999)
)
sched = ReduceLROnPlateau(
    opt,
    mode="min",
    factor=0.5,
    patience=10,
    cooldown=10,
    min_lr=1e-6,
    verbose=True,
)

for curr_epoch in range(epoch):
    print("Run training dalle ...")
    print(f"Epoch {curr_epoch+1} / {epoch}")
    
    for batch_idx in tqdm(idx_list):
        if (batch_idx + batch_size) > train_size - 1:
            iter_idx = range(batch_idx, train_size, 1)
        else:
            iter_idx = range(batch_idx, batch_idx+batch_size, 1)

        batch_len = 0
        total_loss = torch.tensor(0., device=device)

        for curr_idx in iter_idx:
            image, target = train_data[curr_idx]
            image = image.unsqueeze(0).type(torch.FloatTensor).to(device)
            
            texts = tokenizer.tokenize(target).type(torch.LongTensor).to(device)
            
            for text in texts:
                if total_loss == torch.tensor(0., device=device):
                    total_loss = dalle(text.unsqueeze(0), image, return_loss=True)
                else:
                    total_loss += dalle(text.unsqueeze(0), image, return_loss=True)
                batch_len += 1
                
        avg_loss = total_loss / batch_len

        opt.zero_grad()
        avg_loss.backward()
        opt.step()
        
        if batch_idx % 100 == 0:
            torch.save(dalle.state_dict(), dalle_save_path)
            print(f"average loss: {avg_loss.data}")
        
    sched.step(avg_loss)

torch.save(dalle.state_dict(), dalle_save_path)