## Create PyTorch Dataset

In [1]:
import transformers
print(transformers.__version__)

4.29.1


In [2]:
import pandas as pd
from glob import  glob
import json
from tqdm.notebook import trange, tqdm
from sklearn.model_selection import StratifiedKFold
import torch
import numpy as np
import random
import os
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
from torch.cuda.amp import GradScaler, autocast

In [3]:
image_path = glob('./train/images/*')

In [4]:
label_path = glob('./train/annotations/*')

In [5]:
assert len(image_path) == 60578

In [6]:
assert len(label_path)== 60578

### Understanding `max_patches` argument

The paper introduces a new paradigm for processing the input image. It takes the image and create `n_patches` aspect-ratio preserving patches, and concatenates the remaining sequence with padding tokens to finally get `max_patches` patches. It appears that this argument is quite crucial for training and evaluation, as the model becomes very sensitive to this parameter.

For the sake of our example, we will fine-tune a model with `max_patches=1024`.

Note that most of the `-base` models have been fine-tuned with `max_patches=2048`, and `4096` for `-large` models.

In [7]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
from tqdm.notebook import tqdm
MAX_PATCHES = 2048

class ImageCaptioningDataset(Dataset):
    def __init__(self, df, processor):
        self.dataset = df
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        row = self.dataset.iloc[idx, :]
        image = Image.open(row.image_path)
        #display(image)
        encoding = self.processor(images=image,
                                  #prompt
                                  #文本检测 去检测x y轴的label x-axis<X,X,X,X,> Y-axis<y,y,y,y>
                                  # DBnet
                                  text="Generate underlying data table of the figure below:",
                                  font_path="arial.ttf",
                                  return_tensors="pt",
                                  add_special_tokens=True, max_patches=MAX_PATCHES)
        
        encoding = {k:v.squeeze() for k,v in encoding.items()}
        encoding["text"] = row.label 
        return encoding

## Load model and processor

In [8]:
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration

processor = Pix2StructProcessor.from_pretrained("google/matcha-base")
model = Pix2StructForConditionalGeneration.from_pretrained("google/matcha-plotqa-v2")

Now that we have loaded the processor, let's load the dataset and the dataloader:

In [9]:
def collator(batch):
    new_batch = {"flattened_patches": [], "attention_mask": []}
    texts = [item["text"] for item in batch]
    # print(texts)
    text_inputs = processor.tokenizer(text=texts,
                                      padding="max_length",
                                      return_tensors="pt",
                                      add_special_tokens=True,
                                      max_length=512,
                                      truncation=True
                                      )

    new_batch["labels"] = text_inputs.input_ids

    for item in batch:
        new_batch["flattened_patches"].append(item["flattened_patches"])
        new_batch["attention_mask"].append(item["attention_mask"])

    new_batch["flattened_patches"] = torch.stack(new_batch["flattened_patches"])
    new_batch["attention_mask"] = torch.stack(new_batch["attention_mask"])

    return new_batch

In [10]:
df = pd.read_csv('train_with_fold.csv')
print(len(df))
train_df = df[df['fold'] != 0]
train_df.reset_index(drop=True, inplace=True)

60578


In [13]:
class CFG:
    scheduler = 'cosine'  # ['linear', 'cosine']
    batch_scheduler = True
    num_cycles = 0.5  # 1.5
    num_warmup_steps = 0.2
    max_input_length = 130
    epochs = 10  # 5
    encoder_lr = 10e-6
    decoder_lr = 10e-6
    min_lr = 0.5e-6
    eps = 1e-6
    betas = (0.9, 0.999)
    weight_decay = 0
    num_fold = 5
    batch_size = 2
    seed = 1006
    num_workers = 2
    device='cuda:1'
    print_freq = 100

In [14]:
train_dataset = ImageCaptioningDataset(train_df, processor)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=CFG.batch_size, collate_fn=collator, pin_memory=True,
                                  prefetch_factor=40, num_workers=2)

In [15]:
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
seed_everything(CFG.seed)

In [16]:
def get_scheduler(cfg, optimizer, num_train_steps):
    cfg.num_warmup_steps = cfg.num_warmup_steps * num_train_steps
    if cfg.scheduler == 'linear':
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=num_train_steps
        )
    elif cfg.scheduler == 'cosine':
        scheduler = get_cosine_schedule_with_warmup(
            optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=num_train_steps,
            num_cycles=cfg.num_cycles
        )
    return scheduler

num_train_steps = int(len(train_dataset) / CFG.batch_size * CFG.epochs)


## Train the model

Let's train the model! Run the simply the cell below for training the model. We have observed that finding the best hyper-parameters was quite challenging and required a lot of trials and errors, as the model can easily enter in "collapse-model" (always predicting the same output, no matter the input) if the HP are not chosen correctly. In this example, we found out that using `AdamW` optimizer with `lr=1e-5` seemed to be the best approach.

Let's also print the generation output of the model each 10 epochs!

Bear in mind that the model took some time to converge, for instance to get decent results we had to let the script run for ~1hour.

In [17]:
EPOCHS = 10

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scheduler = get_scheduler(CFG, optimizer, num_train_steps)
device = "cuda:1" if torch.cuda.is_available() else "cpu"
print(device)
model.to(device)
scaler = torch.cuda.amp.GradScaler()
model.train()

cuda:1


Pix2StructForConditionalGeneration(
  (encoder): Pix2StructVisionModel(
    (embeddings): Pix2StructVisionEmbeddings(
      (patch_projection): Linear(in_features=768, out_features=768, bias=True)
      (row_embedder): Embedding(4096, 768)
      (column_embedder): Embedding(4096, 768)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): Pix2StructVisionEncoder(
      (layer): ModuleList(
        (0): Pix2StructVisionLayer(
          (attention): Pix2StructVisionAttention(
            (query): Linear(in_features=768, out_features=768, bias=False)
            (key): Linear(in_features=768, out_features=768, bias=False)
            (value): Linear(in_features=768, out_features=768, bias=False)
            (output): Linear(in_features=768, out_features=768, bias=False)
          )
          (mlp): Pix2StructVisionMlp(
            (wi_0): Linear(in_features=768, out_features=2048, bias=False)
            (wi_1): Linear(in_features=768, out_features=2048, bias=False)
         

In [20]:
loss_file = open("loss.txt","w")

for epoch in range(CFG.epochs):
    print("Epoch:", epoch)
    for idx, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
        labels = batch.pop("labels").to(device)
        flattened_patches = batch.pop("flattened_patches").to(device)
        attention_mask = batch.pop("attention_mask").to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(flattened_patches=flattened_patches,
                            attention_mask=attention_mask,
                            labels=labels)

        loss = outputs.loss

        scaler.scale(loss).backward()
        #grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1000)
        # Unscales gradients and calls
        # or skips optimizer.step()
        scaler.step(optimizer)
        # Updates the scale for next iteration
        scaler.update()
        scheduler.step()
        if idx % 100 == 0:
            print("Loss:", loss.item(), f'lr : {scheduler.get_lr()[0]:.6f} ', sep=' ')
            loss_file.write(f"Epoch: {epoch}, Iteration: {idx}, Loss: {loss.item()}\n")
            loss_file.flush()
        

    if (epoch + 1) % 1 == 0:
        torch.save(model.state_dict(), f'./matcha_v1/matcha_{epoch}.bin')

loss_file.close()
        

Epoch: 0


  0%|          | 0/24231 [00:00<?, ?it/s]

Loss: 0.6629443764686584 lr : 0.000010 
Loss: 0.7359920144081116 lr : 0.000010 
Loss: 1.233712911605835 lr : 0.000010 
Loss: 1.0288405418395996 lr : 0.000010 
Loss: 0.858933687210083 lr : 0.000010 
Loss: 1.2682100534439087 lr : 0.000010 
Loss: 0.7170652747154236 lr : 0.000010 
Loss: 0.6052768230438232 lr : 0.000010 
Loss: 0.481381893157959 lr : 0.000010 
Loss: 0.5676138401031494 lr : 0.000010 
Loss: 1.1902427673339844 lr : 0.000010 
Loss: 0.6400188207626343 lr : 0.000010 
Loss: 0.745026171207428 lr : 0.000010 
Loss: 1.1922599077224731 lr : 0.000010 
Loss: 0.5645087361335754 lr : 0.000010 
Loss: 0.31349435448646545 lr : 0.000010 


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Loss: 1.0860344171524048 lr : 0.000010 
Loss: 0.7062138319015503 lr : 0.000010 
Loss: 0.8567994832992554 lr : 0.000010 
Loss: 0.8080344200134277 lr : 0.000010 
Loss: 0.82718825340271 lr : 0.000010 
Loss: 0.40620002150535583 lr : 0.000010 
Loss: 0.68768310546875 lr : 0.000010 
Loss: 0.650678277015686 lr : 0.000010 
Loss: 0.7405890822410583 lr : 0.000010 
Loss: 1.1742229461669922 lr : 0.000010 
Loss: 0.9242892265319824 lr : 0.000010 
Loss: 1.201245665550232 lr : 0.000010 
Loss: 1.696030855178833 lr : 0.000010 
Loss: 0.9377285242080688 lr : 0.000010 
Loss: 0.912734866142273 lr : 0.000010 
Loss: 0.6134945750236511 lr : 0.000010 
Loss: 0.7733114957809448 lr : 0.000010 
Loss: 1.234897255897522 lr : 0.000010 
Loss: 0.94827800989151 lr : 0.000010 
Loss: 1.061649203300476 lr : 0.000010 
Loss: 0.836963951587677 lr : 0.000010 
Loss: 1.1883182525634766 lr : 0.000010 
Loss: 0.46650922298431396 lr : 0.000010 
Loss: 1.118831753730774 lr : 0.000010 
Loss: 0.8363049030303955 lr : 0.000010 
Loss: 0.9104

  0%|          | 0/24231 [00:00<?, ?it/s]

Loss: 0.7556231617927551 lr : 0.000010 
Loss: 0.03452817723155022 lr : 0.000010 
Loss: 1.0818331241607666 lr : 0.000010 
Loss: 1.1988829374313354 lr : 0.000010 
Loss: 1.0498042106628418 lr : 0.000010 
Loss: 0.8261632919311523 lr : 0.000010 
Loss: 0.6496695876121521 lr : 0.000010 
Loss: 1.0359430313110352 lr : 0.000010 
Loss: 0.6875914931297302 lr : 0.000010 
Loss: 0.8585395812988281 lr : 0.000010 
Loss: 0.7067690491676331 lr : 0.000010 
Loss: 0.9208577275276184 lr : 0.000010 
Loss: 0.6945407390594482 lr : 0.000010 
Loss: 0.7719060182571411 lr : 0.000010 
Loss: 0.9553905725479126 lr : 0.000010 
Loss: 0.5768434405326843 lr : 0.000010 
Loss: 1.2601056098937988 lr : 0.000010 
Loss: 1.325897455215454 lr : 0.000010 
Loss: 0.303303599357605 lr : 0.000010 
Loss: 0.38255560398101807 lr : 0.000010 
Loss: 0.8272032737731934 lr : 0.000010 
Loss: 0.3749936521053314 lr : 0.000010 
Loss: 0.6913418769836426 lr : 0.000010 
Loss: 0.6609581708908081 lr : 0.000010 
Loss: 0.7606866359710693 lr : 0.000010 


Loss: 0.4050387740135193 lr : 0.000009 
Loss: 1.107010841369629 lr : 0.000009 
Loss: 0.23013031482696533 lr : 0.000009 
Loss: 1.1122790575027466 lr : 0.000009 
Loss: 0.5626463890075684 lr : 0.000009 
Loss: 1.1351053714752197 lr : 0.000009 
Loss: 0.5769360065460205 lr : 0.000009 
Loss: 1.1357449293136597 lr : 0.000009 
Loss: 0.7697697877883911 lr : 0.000009 
Loss: 0.7477922439575195 lr : 0.000009 
Loss: 0.8572719097137451 lr : 0.000009 
Loss: 0.6417900919914246 lr : 0.000009 
Loss: 0.6547920107841492 lr : 0.000009 
Loss: 0.5609573721885681 lr : 0.000009 
Loss: 1.3102010488510132 lr : 0.000009 
Loss: 0.9055250883102417 lr : 0.000009 
Loss: 1.0428696870803833 lr : 0.000009 
Loss: 0.0032809986732900143 lr : 0.000009 
Loss: 1.142601728439331 lr : 0.000009 
Loss: 0.47688549757003784 lr : 0.000009 
Loss: 0.8381931781768799 lr : 0.000009 
Loss: 1.4493244886398315 lr : 0.000009 
Loss: 0.9181467890739441 lr : 0.000009 
Loss: 0.7186872363090515 lr : 0.000009 
Loss: 0.6988481283187866 lr : 0.00000

  0%|          | 0/24231 [00:00<?, ?it/s]

Loss: 0.7235326766967773 lr : 0.000009 
Loss: 1.192431926727295 lr : 0.000009 
Loss: 0.6627061367034912 lr : 0.000009 
Loss: 0.5902383923530579 lr : 0.000009 
Loss: 0.8919805288314819 lr : 0.000009 
Loss: 1.015995740890503 lr : 0.000009 
Loss: 0.5488107204437256 lr : 0.000009 
Loss: 0.6607364416122437 lr : 0.000008 
Loss: 1.4169526100158691 lr : 0.000008 
Loss: 0.955756425857544 lr : 0.000008 
Loss: 0.8035165071487427 lr : 0.000008 
Loss: 1.1968417167663574 lr : 0.000008 
Loss: 0.6570849418640137 lr : 0.000008 
Loss: 1.0051250457763672 lr : 0.000008 
Loss: 0.6556007862091064 lr : 0.000008 
Loss: 0.6258431673049927 lr : 0.000008 
Loss: 0.8131260275840759 lr : 0.000008 
Loss: 1.1318644285202026 lr : 0.000008 
Loss: 0.44483518600463867 lr : 0.000008 
Loss: 0.7023304104804993 lr : 0.000008 
Loss: 0.8564227819442749 lr : 0.000008 
Loss: 0.999812662601471 lr : 0.000008 
Loss: 0.845029354095459 lr : 0.000008 
Loss: 1.015466570854187 lr : 0.000008 
Loss: 1.106616497039795 lr : 0.000008 
Loss: 

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Loss: 1.118545413017273 lr : 0.000007 
Loss: 0.5234410166740417 lr : 0.000007 
Loss: 0.28771212697029114 lr : 0.000007 
Loss: 0.6947903633117676 lr : 0.000007 
Loss: 0.8874111175537109 lr : 0.000007 
Loss: 0.6647882461547852 lr : 0.000007 
Loss: 0.811589777469635 lr : 0.000007 
Loss: 0.8131832480430603 lr : 0.000007 
Loss: 0.48346948623657227 lr : 0.000007 
Loss: 0.6390827894210815 lr : 0.000007 
Loss: 0.6688669323921204 lr : 0.000007 
Loss: 0.5917144417762756 lr : 0.000007 
Loss: 1.029126763343811 lr : 0.000007 
Loss: 0.7917524576187134 lr : 0.000007 
Loss: 0.887781023979187 lr : 0.000007 
Loss: 0.6180808544158936 lr : 0.000007 
Loss: 0.9798905849456787 lr : 0.000007 
Loss: 0.5652866363525391 lr : 0.000007 
Loss: 0.6522808074951172 lr : 0.000007 
Loss: 1.161333680152893 lr : 0.000007 
Loss: 0.639288067817688 lr : 0.000007 
Loss: 1.4931292533874512 lr : 0.000007 
Loss: 0.5799635648727417 lr : 0.000007 
Loss: 1.0596197843551636 lr : 0.000007 
Loss: 0.8263518810272217 lr : 0.000007 
Loss

  0%|          | 0/24231 [00:00<?, ?it/s]

Loss: 0.7634539008140564 lr : 0.000007 
Loss: 0.15728087723255157 lr : 0.000007 
Loss: 0.6131327152252197 lr : 0.000007 
Loss: 0.8122766613960266 lr : 0.000007 
Loss: 0.6121813058853149 lr : 0.000007 
Loss: 1.1065603494644165 lr : 0.000007 
Loss: 0.13333803415298462 lr : 0.000007 
Loss: 1.0836213827133179 lr : 0.000007 
Loss: 0.7383244037628174 lr : 0.000007 
Loss: 1.1530327796936035 lr : 0.000007 
Loss: 0.5047546029090881 lr : 0.000007 
Loss: 0.8137045502662659 lr : 0.000007 
Loss: 0.8050179481506348 lr : 0.000007 
Loss: 0.9557466506958008 lr : 0.000007 
Loss: 0.5749271512031555 lr : 0.000007 
Loss: 0.33513879776000977 lr : 0.000007 
Loss: 0.3584668040275574 lr : 0.000007 
Loss: 0.36349427700042725 lr : 0.000007 
Loss: 0.5554749965667725 lr : 0.000007 
Loss: 1.5390669107437134 lr : 0.000007 
Loss: 0.40759602189064026 lr : 0.000007 
Loss: 0.6623473167419434 lr : 0.000007 
Loss: 0.6113099455833435 lr : 0.000007 
Loss: 1.3107919692993164 lr : 0.000007 
Loss: 0.34883350133895874 lr : 0.00

Loss: 0.5043783187866211 lr : 0.000005 
Loss: 0.4951121211051941 lr : 0.000005 
Loss: 0.7514123916625977 lr : 0.000005 
Loss: 1.1417346000671387 lr : 0.000005 
Loss: 0.6692086458206177 lr : 0.000005 
Loss: 0.9038362503051758 lr : 0.000005 
Loss: 0.46428561210632324 lr : 0.000005 
Loss: 1.001158595085144 lr : 0.000005 
Loss: 0.5371847152709961 lr : 0.000005 
Loss: 0.519828736782074 lr : 0.000005 
Loss: 0.6938729882240295 lr : 0.000005 
Loss: 0.5223884582519531 lr : 0.000005 
Loss: 0.7054895162582397 lr : 0.000005 
Loss: 0.2997346818447113 lr : 0.000005 
Loss: 0.5559055209159851 lr : 0.000005 
Loss: 0.6503329873085022 lr : 0.000005 
Loss: 0.9274314641952515 lr : 0.000005 
Loss: 0.6246740818023682 lr : 0.000005 
Loss: 0.4556489884853363 lr : 0.000005 
Loss: 0.8858939409255981 lr : 0.000005 
Loss: 0.7919100522994995 lr : 0.000005 
Loss: 0.6550304889678955 lr : 0.000005 
Loss: 0.43188223242759705 lr : 0.000005 
Loss: 0.6248058676719666 lr : 0.000005 
Loss: 0.41390520334243774 lr : 0.000005 

  0%|          | 0/24231 [00:00<?, ?it/s]

Loss: 0.6827391386032104 lr : 0.000005 
Loss: 0.8731006979942322 lr : 0.000005 
Loss: 0.5487709641456604 lr : 0.000005 
Loss: 0.7064599394798279 lr : 0.000005 
Loss: 0.6545484662055969 lr : 0.000005 
Loss: 0.33575066924095154 lr : 0.000005 
Loss: 0.42181992530822754 lr : 0.000005 
Loss: 0.9168716669082642 lr : 0.000005 
Loss: 0.5497962236404419 lr : 0.000005 
Loss: 0.6182078123092651 lr : 0.000005 
Loss: 1.1085460186004639 lr : 0.000005 
Loss: 0.5720779895782471 lr : 0.000005 
Loss: 0.17523843050003052 lr : 0.000005 
Loss: 0.9560202360153198 lr : 0.000005 
Loss: 1.1202325820922852 lr : 0.000005 
Loss: 0.8056736588478088 lr : 0.000005 
Loss: 0.32379767298698425 lr : 0.000005 
Loss: 0.40775689482688904 lr : 0.000005 
Loss: 0.7668707370758057 lr : 0.000005 
Loss: 0.7287191152572632 lr : 0.000005 
Loss: 0.5547690391540527 lr : 0.000005 
Loss: 1.5502018928527832 lr : 0.000005 
Loss: 0.7821994423866272 lr : 0.000005 
Loss: 0.4362436830997467 lr : 0.000005 
Loss: 0.816501796245575 lr : 0.0000

Loss: 1.2012428045272827 lr : 0.000001 
Loss: 1.2473483085632324 lr : 0.000001 
Loss: 0.2622053921222687 lr : 0.000001 
Loss: 0.9974551796913147 lr : 0.000001 
Loss: 0.5858674645423889 lr : 0.000001 
Loss: 1.6607046127319336 lr : 0.000001 
Loss: 1.1372934579849243 lr : 0.000001 
Loss: 0.7832227945327759 lr : 0.000001 
Loss: 0.571606457233429 lr : 0.000001 
Loss: 0.6467158198356628 lr : 0.000001 
Loss: 1.0601569414138794 lr : 0.000001 
Loss: 0.6577380299568176 lr : 0.000001 
Loss: 0.5495434999465942 lr : 0.000001 
Loss: 0.8550590872764587 lr : 0.000001 
Loss: 0.5401104688644409 lr : 0.000001 
Loss: 0.9272490739822388 lr : 0.000001 
Loss: 1.240848183631897 lr : 0.000001 
Loss: 0.7225064039230347 lr : 0.000001 
Loss: 0.9000689387321472 lr : 0.000001 
Loss: 0.32082512974739075 lr : 0.000001 
Loss: 0.3496215343475342 lr : 0.000001 
Loss: 1.0830706357955933 lr : 0.000001 
Loss: 0.2834169268608093 lr : 0.000001 
Loss: 0.7999798059463501 lr : 0.000000 
Loss: 0.8917604088783264 lr : 0.000000 
L

  0%|          | 0/24231 [00:00<?, ?it/s]

Loss: 0.8207643628120422 lr : 0.000000 
Loss: 0.7596992254257202 lr : 0.000000 
Loss: 0.6613786816596985 lr : 0.000000 
Loss: 0.823258638381958 lr : 0.000000 
Loss: 0.5922842025756836 lr : 0.000000 
Loss: 0.7599189281463623 lr : 0.000000 
Loss: 0.8797274827957153 lr : 0.000000 
Loss: 0.18269093334674835 lr : 0.000000 
Loss: 0.9406559467315674 lr : 0.000000 
Loss: 0.4452696144580841 lr : 0.000000 
Loss: 1.6097608804702759 lr : 0.000000 
Loss: 0.5526441335678101 lr : 0.000000 
Loss: 0.7699441909790039 lr : 0.000000 
Loss: 0.30928170680999756 lr : 0.000000 
Loss: 0.819677472114563 lr : 0.000000 
Loss: 0.7497047781944275 lr : 0.000000 
Loss: 0.7219527959823608 lr : 0.000000 
Loss: 0.4240249991416931 lr : 0.000000 
Loss: 1.3104164600372314 lr : 0.000000 
Loss: 0.6524028182029724 lr : 0.000000 
Loss: 0.6703131794929504 lr : 0.000000 
Loss: 0.7009182572364807 lr : 0.000000 
Loss: 0.17330695688724518 lr : 0.000000 
Loss: 0.8128225207328796 lr : 0.000000 
Loss: 0.05697162076830864 lr : 0.000000

Loss: 0.692176342010498 lr : 0.000000 
Loss: 0.34508681297302246 lr : 0.000000 
Loss: 0.9469715356826782 lr : 0.000000 
Loss: 0.7209911942481995 lr : 0.000000 
Loss: 0.5152581334114075 lr : 0.000000 
Loss: 0.7375945448875427 lr : 0.000000 
Loss: 1.5237008333206177 lr : 0.000000 
Loss: 0.5869733691215515 lr : 0.000000 
Loss: 0.6620112657546997 lr : 0.000000 
Loss: 0.8475804328918457 lr : 0.000000 
Loss: 0.49475809931755066 lr : 0.000000 
Loss: 0.7436007857322693 lr : 0.000000 
Loss: 0.7389776706695557 lr : 0.000000 
Loss: 0.8566256761550903 lr : 0.000000 
Loss: 0.6119924783706665 lr : 0.000000 
Loss: 1.3368879556655884 lr : 0.000000 
Loss: 0.8153706789016724 lr : 0.000000 
Loss: 0.2263103574514389 lr : 0.000000 
Loss: 0.5985241532325745 lr : 0.000000 
Loss: 0.6460602283477783 lr : 0.000000 
Loss: 0.41680851578712463 lr : 0.000000 
Loss: 1.072177529335022 lr : 0.000000 
Loss: 0.9424703121185303 lr : 0.000000 
Loss: 0.525698721408844 lr : 0.000000 
Loss: 1.2016626596450806 lr : 0.000000 


  0%|          | 0/24231 [00:00<?, ?it/s]

Loss: 0.6509323120117188 lr : 0.000000 
Loss: 1.474658489227295 lr : 0.000000 
Loss: 0.8113728761672974 lr : 0.000000 
Loss: 0.43169093132019043 lr : 0.000000 
Loss: 0.944196343421936 lr : 0.000000 
Loss: 1.5682525634765625 lr : 0.000000 
Loss: 0.452965646982193 lr : 0.000000 
Loss: 1.1235264539718628 lr : 0.000000 
Loss: 0.5490254759788513 lr : 0.000000 
Loss: 0.18809176981449127 lr : 0.000000 
Loss: 0.20429714024066925 lr : 0.000000 
Loss: 0.7552458047866821 lr : 0.000000 
Loss: 0.6732435822486877 lr : 0.000000 
Loss: 0.3152370750904083 lr : 0.000000 
Loss: 0.4704183042049408 lr : 0.000000 
Loss: 0.9963559508323669 lr : 0.000000 
Loss: 1.170369029045105 lr : 0.000000 
Loss: 0.7466986775398254 lr : 0.000000 
Loss: 0.8500944375991821 lr : 0.000000 
Loss: 0.6573537588119507 lr : 0.000000 
Loss: 0.7682428359985352 lr : 0.000000 
Loss: 0.6529017686843872 lr : 0.000000 
Loss: 1.7234169244766235 lr : 0.000000 
Loss: 0.8686783909797668 lr : 0.000000 
Loss: 1.2084954977035522 lr : 0.000000 
L

Loss: 0.8514475226402283 lr : 0.000000 
Loss: 1.293064832687378 lr : 0.000000 
Loss: 1.0143225193023682 lr : 0.000000 
Loss: 1.0422883033752441 lr : 0.000000 
Loss: 1.2476650476455688 lr : 0.000000 
Loss: 0.24910324811935425 lr : 0.000000 
Loss: 0.7972142696380615 lr : 0.000000 
Loss: 1.1901086568832397 lr : 0.000000 
Loss: 0.42412322759628296 lr : 0.000000 
Loss: 1.1317559480667114 lr : 0.000000 
Loss: 0.6861850619316101 lr : 0.000000 
Loss: 0.39656612277030945 lr : 0.000000 
Loss: 0.8621700406074524 lr : 0.000000 
Loss: 0.7056178450584412 lr : 0.000000 
Loss: 0.9518454074859619 lr : 0.000000 
Loss: 0.9984250664710999 lr : 0.000000 
Loss: 0.5631944537162781 lr : 0.000000 
Loss: 0.777550458908081 lr : 0.000000 
Loss: 0.31235307455062866 lr : 0.000000 
Loss: 0.9698060750961304 lr : 0.000000 
Loss: 0.6368464231491089 lr : 0.000000 
Loss: 0.47314193844795227 lr : 0.000000 
Loss: 0.7375641465187073 lr : 0.000000 
Loss: 0.8368582725524902 lr : 0.000000 
Loss: 0.2975318729877472 lr : 0.00000

  0%|          | 0/24231 [00:00<?, ?it/s]

Loss: 0.822554886341095 lr : 0.000000 
Loss: 1.1946192979812622 lr : 0.000000 
Loss: 0.8359084725379944 lr : 0.000000 
Loss: 0.6560397148132324 lr : 0.000000 
Loss: 0.8776852488517761 lr : 0.000000 
Loss: 0.635638952255249 lr : 0.000000 
Loss: 0.6803064942359924 lr : 0.000000 
Loss: 0.6456636190414429 lr : 0.000000 
Loss: 0.792097806930542 lr : 0.000000 
Loss: 0.6495094895362854 lr : 0.000000 
Loss: 0.33406782150268555 lr : 0.000000 
Loss: 0.8484772443771362 lr : 0.000000 
Loss: 0.6731852293014526 lr : 0.000000 
Loss: 0.8136777877807617 lr : 0.000000 
Loss: 0.24912846088409424 lr : 0.000000 
Loss: 1.0719871520996094 lr : 0.000000 
Loss: 0.6304479837417603 lr : 0.000000 
Loss: 0.7944542169570923 lr : 0.000000 
Loss: 1.0570679903030396 lr : 0.000000 
Loss: 1.469242811203003 lr : 0.000000 
Loss: 0.15994170308113098 lr : 0.000000 
Loss: 0.8354843258857727 lr : 0.000000 
Loss: 1.0803004503250122 lr : 0.000000 
Loss: 0.3736814856529236 lr : 0.000000 
Loss: 0.4584267735481262 lr : 0.000000 
L

Loss: 0.26986783742904663 lr : 0.000001 
Loss: 0.7213385105133057 lr : 0.000001 
Loss: 1.2187060117721558 lr : 0.000001 
Loss: 0.7964249849319458 lr : 0.000001 
Loss: 0.6108760237693787 lr : 0.000001 
Loss: 0.7462961077690125 lr : 0.000001 
Loss: 0.7796894907951355 lr : 0.000001 
Loss: 1.0889841318130493 lr : 0.000001 
Loss: 0.2672808766365051 lr : 0.000001 
Loss: 0.35482853651046753 lr : 0.000001 
Loss: 0.9952284097671509 lr : 0.000001 
Loss: 1.1094893217086792 lr : 0.000001 
Loss: 0.2425563782453537 lr : 0.000001 
Loss: 1.1115297079086304 lr : 0.000001 
Loss: 1.2193368673324585 lr : 0.000001 
Loss: 1.014015555381775 lr : 0.000001 
Loss: 0.7908509373664856 lr : 0.000001 
Loss: 1.1475123167037964 lr : 0.000001 
Loss: 0.6541234254837036 lr : 0.000001 
Loss: 1.2704393863677979 lr : 0.000001 
Loss: 0.9871131777763367 lr : 0.000001 
Loss: 0.9672818183898926 lr : 0.000001 
Loss: 0.6557827591896057 lr : 0.000001 
Loss: 1.2765076160430908 lr : 0.000001 
Loss: 0.8192527890205383 lr : 0.000001 