In [2]:
%env CUDA_VISIBLE_DEVICES=0
%env TRANSFORMERS_CACHE=/mnt/LLM/hub
%env OMP_NUM_THREADS=16

import os
import sys
sys.path.insert(0, '..')

import time
import random
from tqdm.auto import trange
import ipynbname  # pip install ipynbname

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers

from src.aq import QuantizedWeight, _reconstruct_weight  # see adjacent file (aq.py)
from src.utils import calc_avg_bits, get_mean_nbits_by_codebook  # see adjacent file (aq.py)

torch.set_num_threads(16)
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

input_loading_dir = '/extra_disk_1/vahe1994/BRRR/layer10.self_attn.q_proj.input_activation.pt'
num_codebooks = 1
nbits_per_codebook = 14
out_group_size = 1
in_group_size = 8
batch_size = 16384
beam_size = 1
beam_search_epochs = 100
sparsity_regularizer = 0
print_frequency = 10
scale_nbits = 2    # 0 means no scales, 16 means no compression;


#for each group we store num_codebooks * nbits_per_codebook - bits 
# for W matrix  we store out_features*(in_features // group_size) * num_codebooks * nbits_per_codebook - bits 
# 1 codebook store codebook_size*group_size*16 - bit 
# all codebooks store num_codebooks* codebook_size*group_size*16 -bits
in_features, out_features  = 8192, 8192
estimated_bits_per_param = calc_avg_bits(
    num_codebooks, out_group_size, in_group_size, nbits_per_codebook, in_features, out_features, scale_nbits)
print("Estimated bits / param", estimated_bits_per_param)

env: CUDA_VISIBLE_DEVICES=0
env: TRANSFORMERS_CACHE=/mnt/LLM/hub
env: OMP_NUM_THREADS=16
Estimated bits / param 2.0390625


In [3]:
import wandb

# Resolve the notebook name once; it is used both for code saving and for the run name.
notebook_name = ipynbname.name()
os.environ["WANDB_NOTEBOOK_NAME"] = os.path.join(os.getcwd(), notebook_name + ".ipynb")

# Start a new wandb run to track this experiment.
run = wandb.init(
    dir=os.getcwd(),
    project="AddQuantization",  # wandb project where this run is logged
    entity="rock-and-roll",
    save_code=True,
    name=f"{notebook_name}_AQ_{num_codebooks=}_{out_group_size=}_{in_group_size=}_{nbits_per_codebook=}_{beam_search_epochs=}",
    settings=wandb.Settings(code_dir="."),
    # Hyperparameters and run metadata.
    config={
        "num_codebooks": num_codebooks,
        "out_group_size": out_group_size,
        "in_group_size": in_group_size,
        "group_size": out_group_size * in_group_size,
        "batch_size": batch_size,
        "beam_size": beam_size,
        "nbits_per_codebook": nbits_per_codebook,
        "Avg_bits": estimated_bits_per_param,
        "beam_search_epochs": beam_search_epochs,
        "sparsity_regularizer": sparsity_regularizer,
        "scale_nbits": scale_nbits,  # was misspelled "scale_nvits" in earlier runs
    },
)
run.log({"Avg_bits": estimated_bits_per_param})

[34m[1mwandb[0m: Currently logged in as: [33mjustheuristic[0m ([33mrock-and-roll[0m). Use [1m`wandb login --relogin`[0m to force relogin




In [4]:
# The full model is loaded only to read the reference weight of the layer being quantized.
model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf", torch_dtype='auto', low_cpu_mem_usage=True)

# Calibration activations, with all leading dims collapsed: (num_tokens, in_features).
# NOTE: torch.load unpickles the file; only load trusted, locally produced files.
X = torch.load(input_loading_dir, map_location='cpu').float().flatten(0, -2)
reference_weight = model.model.layers[10].self_attn.q_proj.weight.detach().to(device).float()
# Fail fast if the hard-coded sizes used for bit accounting do not match the actual layer/inputs.
assert reference_weight.shape == (out_features, in_features), reference_weight.shape
assert X.shape[-1] == in_features, X.shape

# XTX = X^T X / num_tokens, accumulated batch-by-batch in float64 to limit rounding error
# and device memory. Batches go to `device` (not a hard-coded .cuda()) so the CPU fallback works.
XTX = torch.zeros(X.shape[-1], X.shape[-1], device=device, dtype=torch.float64)
for i in range(0, len(X), batch_size):
    x_batch = X[i: i + batch_size].to(device).double()
    XTX.addmm_(x_batch.T, x_batch, alpha=1 / len(X))
    del x_batch
XTX = XTX.float()
del X

Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]

In [None]:
quantized_layer = QuantizedWeight(
    reference_weight=reference_weight, num_codebooks=num_codebooks,
    nbits_per_codebook=nbits_per_codebook, scale_nbits=scale_nbits,
    out_group_size=out_group_size, in_group_size=in_group_size,
    max_iter=100,   # faster init, not tested
    verbose=True
)
opt = torch.optim.Adam(quantized_layer.parameters(), lr=1e-4, betas=(0.0, 0.95), amsgrad=True)

# XTX does not change during training, so cast it to float64 once instead of every epoch.
XTX_double = XTX.double()

for epoch in range(10_000):
    start = time.perf_counter()
    # Gradient step on quantized_layer's trainable parameters (presumably codebooks/scales).
    reconstructed_weight = quantized_layer()
    delta_weight = (reconstructed_weight - reference_weight).double()
    # trace(dW @ XTX @ dW^T) / out_features == ||X @ dW^T||_F^2 / (num_tokens * out_features),
    # i.e. the mean squared error of the layer's outputs on the calibration activations.
    loss = (delta_weight @ XTX_double).flatten() @ delta_weight.flatten() / len(delta_weight)
    opt.zero_grad()
    loss.backward()
    opt.step()

    loss_value = loss.item()  # single host sync per epoch
    run.log({'loss': loss_value}, step=epoch)

    if epoch % print_frequency == 0:
        print(f"loss={loss_value:.10f}\t",
              f"time_on_epoch {epoch} = {time.perf_counter() - start}")
    if (epoch + 1) % beam_search_epochs == 0:
        # Periodically re-assign the discrete codes (beam search of width beam_size).
        quantized_layer.requantize_(
            XTX, reference_weight,
            beam_size=beam_size,
            sparsity_regularizer=sparsity_regularizer,  # tip: use const_hparam * quantized_layer.codes.numel()
            dim_rng=random.Random(), verbose=True)

        if sparsity_regularizer != 0:
            sparsity_rate = ((quantized_layer.codes == 0).sum() / quantized_layer.codes.numel()).item()
            print(f"Sparsity rate {sparsity_rate:.5f}")
            run.log({'sparsity rate': sparsity_rate}, step=epoch)
            mean_code_nbits = sum(get_mean_nbits_by_codebook(quantized_layer.codes)) / num_codebooks
            print(f"mean_code_nbits {mean_code_nbits:.5f}")
            run.log({'Mean codebook length nbits': mean_code_nbits}, step=epoch)
            # Effective bits/param when codes cost mean_code_nbits on average instead of
            # nbits_per_codebook. Positional order matches the setup-time estimate:
            # (num_codebooks, out_group_size, in_group_size, nbits, in_features, out_features, scale_nbits).
            curr_avg_bits = calc_avg_bits(num_codebooks, out_group_size, in_group_size, mean_code_nbits,
                                          in_features, out_features, scale_nbits)
            run.log({"Avg_bits": curr_avg_bits}, step=epoch)

  groupwise_cluster_indices = torch.searchsorted(border_indices[:, 1:], groupwise_ranks_1based, side='left')


initializing with kmeans:   0%|          | 0/1 [00:00<?, ?it/s]

  codebook_i, codes_i, reconstructed_weight_i = fit_kmeans(weight_residue, k=codebook_size, **kwargs)


loss=0.0258857615	 time_on_epoch 0 = 1.794874879065901
loss=0.0214160136	 time_on_epoch 10 = 0.13309287303127348
loss=0.0179490415	 time_on_epoch 20 = 0.13319369300734252
loss=0.0152740230	 time_on_epoch 30 = 0.13324946200009435
loss=0.0132520099	 time_on_epoch 40 = 0.1333391530206427
loss=0.0117486446	 time_on_epoch 50 = 0.13315503299236298
loss=0.0106320357	 time_on_epoch 60 = 0.13317235291469842
loss=0.0097964607	 time_on_epoch 70 = 0.13332686305511743
loss=0.0091640784	 time_on_epoch 80 = 0.13326736295130104
loss=0.0086792293	 time_on_epoch 90 = 0.13343094300944358


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0035189353	 time_on_epoch 100 = 0.1336318530375138
loss=0.0035091678	 time_on_epoch 110 = 0.13316636299714446
loss=0.0035043052	 time_on_epoch 120 = 0.1333161829970777
loss=0.0035010309	 time_on_epoch 130 = 0.13334985298570246
loss=0.0034985890	 time_on_epoch 140 = 0.1332549329381436
loss=0.0034966450	 time_on_epoch 150 = 0.1332929430063814
loss=0.0034950249	 time_on_epoch 160 = 0.13314009306486696
loss=0.0034936290	 time_on_epoch 170 = 0.1331540320534259
loss=0.0034923968	 time_on_epoch 180 = 0.13312971289269626
loss=0.0034912889	 time_on_epoch 190 = 0.1332523429300636


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0027500544	 time_on_epoch 200 = 0.1337201700080186
loss=0.0027436652	 time_on_epoch 210 = 0.1331480590160936
loss=0.0027403748	 time_on_epoch 220 = 0.1331907599233091
loss=0.0027380148	 time_on_epoch 230 = 0.13311777007766068
loss=0.0027361537	 time_on_epoch 240 = 0.1331331899855286
loss=0.0027346007	 time_on_epoch 250 = 0.1332436390221119
loss=0.0027332553	 time_on_epoch 260 = 0.13381329004187137
loss=0.0027320593	 time_on_epoch 270 = 0.13322458998300135
loss=0.0027309766	 time_on_epoch 280 = 0.13325519894715399
loss=0.0027299832	 time_on_epoch 290 = 0.13342676998581737


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0025152603	 time_on_epoch 300 = 0.1336500939214602
loss=0.0025119187	 time_on_epoch 310 = 0.13334613398183137
loss=0.0025101011	 time_on_epoch 320 = 0.13310248404741287
loss=0.0025087444	 time_on_epoch 330 = 0.13309371401555836
loss=0.0025076410	 time_on_epoch 340 = 0.13311722502112389
loss=0.0025066980	 time_on_epoch 350 = 0.13310846395324916
loss=0.0025058657	 time_on_epoch 360 = 0.13310975406784564
loss=0.0025051146	 time_on_epoch 370 = 0.13321715500205755
loss=0.0025044264	 time_on_epoch 380 = 0.13321204402018338
loss=0.0025037887	 time_on_epoch 390 = 0.13316425494849682


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0024054532	 time_on_epoch 400 = 0.13364995794836432
loss=0.0024032882	 time_on_epoch 410 = 0.13322594796773046
loss=0.0024020441	 time_on_epoch 420 = 0.13320283708162606
loss=0.0024010908	 time_on_epoch 430 = 0.1335150880040601
loss=0.0024002998	 time_on_epoch 440 = 0.13342391804326326
loss=0.0023996128	 time_on_epoch 450 = 0.13312365801539272
loss=0.0023989984	 time_on_epoch 460 = 0.13316279801074415
loss=0.0023984378	 time_on_epoch 470 = 0.13341390795540065
loss=0.0023979193	 time_on_epoch 480 = 0.1331841079518199
loss=0.0023974349	 time_on_epoch 490 = 0.13317982794251293


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0023415080	 time_on_epoch 500 = 0.13365037995390594
loss=0.0023399954	 time_on_epoch 510 = 0.13321528991218656
loss=0.0023390924	 time_on_epoch 520 = 0.13326122995931655
loss=0.0023383842	 time_on_epoch 530 = 0.13333903101738542
loss=0.0023377855	 time_on_epoch 540 = 0.1331257000565529
loss=0.0023372576	 time_on_epoch 550 = 0.13321992999408394
loss=0.0023367795	 time_on_epoch 560 = 0.1331358200404793
loss=0.0023363389	 time_on_epoch 570 = 0.13315209094434977
loss=0.0023359280	 time_on_epoch 580 = 0.1331702710594982
loss=0.0023355413	 time_on_epoch 590 = 0.13321200106292963


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0023001399	 time_on_epoch 600 = 0.1335208520758897
loss=0.0022990120	 time_on_epoch 610 = 0.13317271205596626
loss=0.0022983210	 time_on_epoch 620 = 0.13321509200613946
loss=0.0022977681	 time_on_epoch 630 = 0.13314453209750354
loss=0.0022972933	 time_on_epoch 640 = 0.13319827197119594
loss=0.0022968691	 time_on_epoch 650 = 0.13315201201476157
loss=0.0022964810	 time_on_epoch 660 = 0.13329762301873416
loss=0.0022961202	 time_on_epoch 670 = 0.13323536200914532
loss=0.0022957811	 time_on_epoch 680 = 0.13341924210544676
loss=0.0022954601	 time_on_epoch 690 = 0.13352238200604916


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0022710219	 time_on_epoch 700 = 0.13358489295933396
loss=0.0022701288	 time_on_epoch 710 = 0.13320212403777987
loss=0.0022695717	 time_on_epoch 720 = 0.13316141394898295
loss=0.0022691209	 time_on_epoch 730 = 0.13321165298111737
loss=0.0022687301	 time_on_epoch 740 = 0.13317778299096972
loss=0.0022683783	 time_on_epoch 750 = 0.13321151304990053
loss=0.0022680542	 time_on_epoch 760 = 0.13343016302678734
loss=0.0022677513	 time_on_epoch 770 = 0.13316466298419982
loss=0.0022674654	 time_on_epoch 780 = 0.1331083329860121
loss=0.0022671934	 time_on_epoch 790 = 0.1331382729113102


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0022491926	 time_on_epoch 800 = 0.1335347139975056
loss=0.0022484702	 time_on_epoch 810 = 0.1332423340063542
loss=0.0022480085	 time_on_epoch 820 = 0.13339278392959386
loss=0.0022476301	 time_on_epoch 830 = 0.13322045397944748
loss=0.0022472990	 time_on_epoch 840 = 0.1334228339837864
loss=0.0022469987	 time_on_epoch 850 = 0.1333410240476951
loss=0.0022467205	 time_on_epoch 860 = 0.1331519439117983
loss=0.0022464591	 time_on_epoch 870 = 0.13314482406713068
loss=0.0022462113	 time_on_epoch 880 = 0.13316596404183656
loss=0.0022459747	 time_on_epoch 890 = 0.13312428398057818


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0022322086	 time_on_epoch 900 = 0.13362795498687774
loss=0.0022315961	 time_on_epoch 910 = 0.13307139498647302
loss=0.0022312020	 time_on_epoch 920 = 0.1331008040579036
loss=0.0022308772	 time_on_epoch 930 = 0.13308523502200842
loss=0.0022305914	 time_on_epoch 940 = 0.13314180402085185
loss=0.0022303311	 time_on_epoch 950 = 0.13348336494527757
loss=0.0022300890	 time_on_epoch 960 = 0.13317679497413337
loss=0.0022298607	 time_on_epoch 970 = 0.13314921397250146
loss=0.0022296436	 time_on_epoch 980 = 0.13315447408240288
loss=0.0022294357	 time_on_epoch 990 = 0.13317280495539308


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0022185609	 time_on_epoch 1000 = 0.13355160504579544
loss=0.0022180553	 time_on_epoch 1010 = 0.13311437505763024
loss=0.0022177219	 time_on_epoch 1020 = 0.1332146949134767
loss=0.0022174432	 time_on_epoch 1030 = 0.13317133497912437
loss=0.0022171954	 time_on_epoch 1040 = 0.13313246506731957
loss=0.0022169680	 time_on_epoch 1050 = 0.1331910250009969
loss=0.0022167552	 time_on_epoch 1060 = 0.13321296509820968
loss=0.0022165536	 time_on_epoch 1070 = 0.13320566504262388
loss=0.0022163612	 time_on_epoch 1080 = 0.13329338503535837
loss=0.0022161764	 time_on_epoch 1090 = 0.1332855150103569


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0022073213	 time_on_epoch 1100 = 0.13361795898526907
loss=0.0022068804	 time_on_epoch 1110 = 0.13313622993882746
loss=0.0022065871	 time_on_epoch 1120 = 0.13325011101551354
loss=0.0022063407	 time_on_epoch 1130 = 0.13311923190485686
loss=0.0022061207	 time_on_epoch 1140 = 0.1331096920184791
loss=0.0022059181	 time_on_epoch 1150 = 0.1331365619553253
loss=0.0022057280	 time_on_epoch 1160 = 0.13330918306019157
loss=0.0022055475	 time_on_epoch 1170 = 0.13315208395943046
loss=0.0022053748	 time_on_epoch 1180 = 0.13319097505882382
loss=0.0022052086	 time_on_epoch 1190 = 0.13313711597584188


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0021978460	 time_on_epoch 1200 = 0.13365623797290027
loss=0.0021974600	 time_on_epoch 1210 = 0.13316985894925892
loss=0.0021972003	 time_on_epoch 1220 = 0.1332180789904669
loss=0.0021969806	 time_on_epoch 1230 = 0.13325959001667798
loss=0.0021967835	 time_on_epoch 1240 = 0.13318252109456807
loss=0.0021966013	 time_on_epoch 1250 = 0.13354516099207103
loss=0.0021964298	 time_on_epoch 1260 = 0.1334548619342968
loss=0.0021962664	 time_on_epoch 1270 = 0.13369149109348655
loss=0.0021961098	 time_on_epoch 1280 = 0.1336745120352134
loss=0.0021959589	 time_on_epoch 1290 = 0.13347412296570837


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0021897412	 time_on_epoch 1300 = 0.13356291293166578
loss=0.0021893958	 time_on_epoch 1310 = 0.13311242498457432
loss=0.0021891611	 time_on_epoch 1320 = 0.1331697249552235
loss=0.0021889619	 time_on_epoch 1330 = 0.13321628503035754
loss=0.0021887826	 time_on_epoch 1340 = 0.13320754596497864
loss=0.0021886164	 time_on_epoch 1350 = 0.13324708596337587
loss=0.0021884597	 time_on_epoch 1360 = 0.13321646605618298
loss=0.0021883102	 time_on_epoch 1370 = 0.13314562698360533
loss=0.0021881667	 time_on_epoch 1380 = 0.133235716028139
loss=0.0021880281	 time_on_epoch 1390 = 0.13316802692133933


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0021827188	 time_on_epoch 1400 = 0.13355604303069413
loss=0.0021824076	 time_on_epoch 1410 = 0.13313878397457302
loss=0.0021821959	 time_on_epoch 1420 = 0.13312306394800544
loss=0.0021820151	 time_on_epoch 1430 = 0.13310532295145094
loss=0.0021818518	 time_on_epoch 1440 = 0.13315116404555738
loss=0.0021817000	 time_on_epoch 1450 = 0.13314099400304258
loss=0.0021815564	 time_on_epoch 1460 = 0.13317281391937286
loss=0.0021814193	 time_on_epoch 1470 = 0.1331460940418765
loss=0.0021812873	 time_on_epoch 1480 = 0.13315802498254925
loss=0.0021811597	 time_on_epoch 1490 = 0.1331395850284025


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0021764935	 time_on_epoch 1500 = 0.13353679003193974
loss=0.0021762118	 time_on_epoch 1510 = 0.13341454102192074
loss=0.0021760177	 time_on_epoch 1520 = 0.13310204003937542
loss=0.0021758516	 time_on_epoch 1530 = 0.13317896099761128
loss=0.0021757013	 time_on_epoch 1540 = 0.13321532006375492
loss=0.0021755615	 time_on_epoch 1550 = 0.13314036093652248
loss=0.0021754292	 time_on_epoch 1560 = 0.1333754409570247
loss=0.0021753027	 time_on_epoch 1570 = 0.13323663198389113
loss=0.0021751808	 time_on_epoch 1580 = 0.1331708210054785
loss=0.0021750629	 time_on_epoch 1590 = 0.13324995199218392


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0021709253	 time_on_epoch 1600 = 0.13372634898405522
loss=0.0021706695	 time_on_epoch 1610 = 0.13313379895407706
loss=0.0021704931	 time_on_epoch 1620 = 0.13318433007225394
loss=0.0021703413	 time_on_epoch 1630 = 0.13313273002859205
loss=0.0021702033	 time_on_epoch 1640 = 0.13316781003959477
loss=0.0021700745	 time_on_epoch 1650 = 0.13336658000480384
loss=0.0021699522	 time_on_epoch 1660 = 0.1332062700530514
loss=0.0021698349	 time_on_epoch 1670 = 0.1332719400525093
loss=0.0021697219	 time_on_epoch 1680 = 0.13327125005889684
loss=0.0021696123	 time_on_epoch 1690 = 0.13312160992063582


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0021659785	 time_on_epoch 1700 = 0.13364483299665153
loss=0.0021657455	 time_on_epoch 1710 = 0.13317271298728883
loss=0.0021655839	 time_on_epoch 1720 = 0.13316851295530796
loss=0.0021654444	 time_on_epoch 1730 = 0.1332132830284536
loss=0.0021653173	 time_on_epoch 1740 = 0.13350400293711573
loss=0.0021651985	 time_on_epoch 1750 = 0.13319340301677585
loss=0.0021650856	 time_on_epoch 1760 = 0.13316793297417462
loss=0.0021649772	 time_on_epoch 1770 = 0.1332595330895856
loss=0.0021648725	 time_on_epoch 1780 = 0.13312513404525816
loss=0.0021647710	 time_on_epoch 1790 = 0.13325134303886443


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0021615650	 time_on_epoch 1800 = 0.1337555010104552
loss=0.0021613521	 time_on_epoch 1810 = 0.1331911919405684
loss=0.0021612034	 time_on_epoch 1820 = 0.1331270820228383
loss=0.0021610746	 time_on_epoch 1830 = 0.13316589198075235
loss=0.0021609571	 time_on_epoch 1840 = 0.13330033200327307
loss=0.0021608470	 time_on_epoch 1850 = 0.13335105206351727
loss=0.0021607421	 time_on_epoch 1860 = 0.1330703420098871
loss=0.0021606413	 time_on_epoch 1870 = 0.13324584299698472
loss=0.0021605439	 time_on_epoch 1880 = 0.13321190199349076
loss=0.0021604493	 time_on_epoch 1890 = 0.13352833304088563


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0021575651	 time_on_epoch 1900 = 0.1335948979249224
loss=0.0021573686	 time_on_epoch 1910 = 0.1335037989774719
loss=0.0021572306	 time_on_epoch 1920 = 0.1331809579860419
loss=0.0021571106	 time_on_epoch 1930 = 0.13313971902243793
loss=0.0021570007	 time_on_epoch 1940 = 0.13333715789485723
loss=0.0021568976	 time_on_epoch 1950 = 0.1332860990660265
loss=0.0021567992	 time_on_epoch 1960 = 0.13330545893404633
loss=0.0021567046	 time_on_epoch 1970 = 0.13318547897506505
loss=0.0021566131	 time_on_epoch 1980 = 0.13314512895885855
loss=0.0021565241	 time_on_epoch 1990 = 0.13316450896672904


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0021539355	 time_on_epoch 2000 = 0.13354079297278076
loss=0.0021537524	 time_on_epoch 2010 = 0.13318546302616596
loss=0.0021536233	 time_on_epoch 2020 = 0.13371747301425785
loss=0.0021535109	 time_on_epoch 2030 = 0.13312378292903304
loss=0.0021534079	 time_on_epoch 2040 = 0.13319656299427152
loss=0.0021533111	 time_on_epoch 2050 = 0.13315940299071372
loss=0.0021532188	 time_on_epoch 2060 = 0.13317721302155405
loss=0.0021531298	 time_on_epoch 2070 = 0.13319819292519242
loss=0.0021530437	 time_on_epoch 2080 = 0.13323167408816516
loss=0.0021529600	 time_on_epoch 2090 = 0.13312618399504572


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0021505469	 time_on_epoch 2100 = 0.13366506597958505
loss=0.0021503773	 time_on_epoch 2110 = 0.13322416704613715
loss=0.0021502564	 time_on_epoch 2120 = 0.13312692707404494
loss=0.0021501509	 time_on_epoch 2130 = 0.13316607708111405
loss=0.0021500539	 time_on_epoch 2140 = 0.1331694460241124
loss=0.0021499626	 time_on_epoch 2150 = 0.13314664701465517
loss=0.0021498754	 time_on_epoch 2160 = 0.13316499697975814
loss=0.0021497913	 time_on_epoch 2170 = 0.1332580370362848
loss=0.0021497098	 time_on_epoch 2180 = 0.13312629703432322
loss=0.0021496306	 time_on_epoch 2190 = 0.13338325603399426


  0%|          | 0/1024 [00:00<?, ?it/s]

loss=0.0021474483	 time_on_epoch 2200 = 0.1337250379147008
