In [12]:
!pip install simclr

[0m

In [13]:
import torch
import torchvision

import numpy as np
import torch.nn as nn
import torchvision.transforms as transforms

from simclr.modules.transformations import TransformsSimCLR
from simclr.modules import NT_Xent
from simclr.modules.identity import Identity

from torch.utils.data.dataloader import DataLoader

import gc
gc.collect()
torch.cuda.empty_cache()

In [14]:
class SimCLR(nn.Module):

    def __init__(self, encoder, projection_dim, n_features):
        super(SimCLR, self).__init__()

        self.encoder = encoder
        self.n_features = n_features

        self.encoder.fc = Identity()
        
        self.projector = nn.Sequential(
            nn.Linear(self.n_features, self.n_features, bias=False, device = DEVICE),
            nn.ReLU(),
            nn.Linear(self.n_features, self.n_features, bias=False, device = DEVICE),
            nn.ReLU(),
        )

    def forward(self, x_i, x_j):
        
        
        h_i = self.encoder(x_i)
        h_j = self.encoder(x_j)
        
        z_i = self.projector(h_i)
        z_j = self.projector(h_j)
        
        x_i = self.encoder.conv1(x_i)
        x_i = self.encoder.bn1(x_i)
        x_i = self.encoder.relu(x_i)
        x_i = self.encoder.maxpool(x_i)
        
        
        #Extract the features from intermediate layers
        h_1 = self.encoder.layer1(x_i)
        h_2 = self.encoder.layer2(h_1)
        h_3 = self.encoder.layer3(h_2)
        h_4 = self.encoder.layer4(h_3)
        
        h_1.norm(dim = 1, p = 2)
        h_2.norm(dim = 1, p = 2)
        h_3.norm(dim = 1, p = 2)
        h_4.norm(dim = 1, p = 2)
        
        return h_i, h_j, z_i, z_j, h_1, h_2, h_3, h_4

In [15]:
image_size = 14
batch_size = 128
projection_dim = 64 #128 for imagenette, 64 for CIFAR10

In [16]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [17]:
tim_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor()
])

train_transform = TransformsSimCLR(size=image_size)
test_transform = tim_transform

train_dataset = torchvision.datasets.CIFAR10(root = '/content/sample_data', transform = train_transform, download = True)
test_dataset = torchvision.datasets.CIFAR10(root = '/content/sample_data', transform = test_transform)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size = batch_size,
    shuffle = False,
    drop_last=True,
    num_workers = 2,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle = False,
    drop_last=True,
    num_workers = 2,
)

Files already downloaded and verified


In [18]:
encoder = torchvision.models.resnet18()
n_features = encoder.fc.in_features  # get dimensions of last fully-connected layer
model = SimCLR(encoder, projection_dim, n_features).to(DEVICE)

In [19]:
import os

from simclr.modules import LARS


def load_optimizer(optimizer, model, batch_size, weight_decay, epochs):
    if optimizer == "Adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)  # TODO: LARS
    elif optimizer == "LARS":
        learning_rate = 0.3 * batch_size / 256
        optimizer = LARS(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            exclude_from_weight_decay=["batch_normalization", "bias"],
        )

        # "decay the learning rate with the cosine decay schedule without restarts"
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, epochs, eta_min=0, last_epoch=-1
        )
    else:
        raise NotImplementedError

    return optimizer, scheduler


def save_model(current_epoch, model, optimizer, model_path):
    out = os.path.join(model_path, "checkpoint_{}.tar".format(current_epoch))

    if isinstance(model, torch.nn.DataParallel):
        torch.save(model.module.state_dict(), out)
    else:
        torch.save(model.state_dict(), out)


In [20]:
from simclr.modules import NT_Xent
from tqdm import tqdm

weight_decay = 1e-6
epochs = 100
start_epoch = 0
model_path = '/kaggle/working'

optimizer, scheduler = load_optimizer('LARS', model, batch_size, weight_decay, epochs)
criterion = NT_Xent(batch_size, temperature = 0.5, world_size = 1)

In [21]:
def train(train_loader, model, criterion, optimizer, global_step):
    loss_epoch = 0
    for step, ((x_i, x_j), _) in enumerate(train_loader):
        optimizer.zero_grad()
        x_i = x_i.cuda()
        x_j = x_j.cuda()

        # positive pair, with encoding
        h_i, h_j, z_i, z_j, _, _, _, _ = model(x_i, x_j)
        
        loss = criterion(z_i, z_j)
        loss.backward()

        optimizer.step()

        if step % 50 == 0:
            print(f"Step [{step}/{len(train_loader)}]\t Loss: {loss.item()}")

        loss_epoch += loss.item()
        global_step += 1
    return loss_epoch

In [22]:
global_step = 0
current_epoch = 0
for epoch in tqdm(range(start_epoch, epochs)):
    lr = optimizer.param_groups[0]["lr"]
    loss_epoch = train(train_loader, model, criterion, optimizer, global_step)

    if scheduler:
        scheduler.step()

    # save every 10 epochs
    if epoch % 10 == 0:
        save_model(current_epoch, model, optimizer, model_path)

    print(
        f"Epoch [{epoch}/{epochs}]\t Loss: {loss_epoch / len(train_loader)}\t lr: {round(lr, 5)}"
    )
    current_epoch += 1

# end training
save_model(current_epoch, model, optimizer, model_path)

	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  /usr/local/src/pytorch/torch/csrc/utils/python_arg_parser.cpp:1055.)
  next_v.mul_(momentum).add_(scaled_lr, grad)


Step [0/390]	 Loss: 5.530717372894287
Step [50/390]	 Loss: 5.477902412414551
Step [100/390]	 Loss: 5.276640892028809
Step [150/390]	 Loss: 5.25023889541626
Step [200/390]	 Loss: 5.195037364959717
Step [250/390]	 Loss: 5.123541831970215
Step [300/390]	 Loss: 5.110691070556641
Step [350/390]	 Loss: 5.073017597198486


  1%|          | 1/100 [01:19<2:10:27, 79.07s/it]

Epoch [0/100]	 Loss: 5.261579210330279	 lr: 0.15
Step [0/390]	 Loss: 5.2655792236328125
Step [50/390]	 Loss: 5.166264533996582
Step [100/390]	 Loss: 5.018362998962402
Step [150/390]	 Loss: 5.161472320556641
Step [200/390]	 Loss: 5.114624500274658
Step [250/390]	 Loss: 5.095117092132568
Step [300/390]	 Loss: 5.049841403961182
Step [350/390]	 Loss: 5.027430534362793


  2%|▏         | 2/100 [02:38<2:09:31, 79.30s/it]

Epoch [1/100]	 Loss: 5.0841063242692215	 lr: 0.14996
Step [0/390]	 Loss: 5.101804733276367
Step [50/390]	 Loss: 5.100811958312988
Step [100/390]	 Loss: 4.925225257873535
Step [150/390]	 Loss: 5.158955097198486
Step [200/390]	 Loss: 5.060067653656006
Step [250/390]	 Loss: 5.058027744293213
Step [300/390]	 Loss: 5.047232627868652
Step [350/390]	 Loss: 4.974496364593506


  3%|▎         | 3/100 [03:56<2:07:21, 78.78s/it]

Epoch [2/100]	 Loss: 5.041183125667083	 lr: 0.14985
Step [0/390]	 Loss: 5.049844264984131
Step [50/390]	 Loss: 5.052509784698486
Step [100/390]	 Loss: 5.0675153732299805
Step [150/390]	 Loss: 5.019184112548828
Step [200/390]	 Loss: 5.024435520172119
Step [250/390]	 Loss: 4.938823699951172
Step [300/390]	 Loss: 5.0369086265563965
Step [350/390]	 Loss: 4.907812118530273


  4%|▍         | 4/100 [05:16<2:06:46, 79.23s/it]

Epoch [3/100]	 Loss: 5.008063716154832	 lr: 0.14967
Step [0/390]	 Loss: 5.000146389007568
Step [50/390]	 Loss: 4.954580783843994
Step [100/390]	 Loss: 4.94312047958374
Step [150/390]	 Loss: 5.021020412445068
Step [200/390]	 Loss: 4.913854122161865
Step [250/390]	 Loss: 4.933616638183594
Step [300/390]	 Loss: 5.004793643951416
Step [350/390]	 Loss: 4.89621639251709


  5%|▌         | 5/100 [06:35<2:05:25, 79.22s/it]

Epoch [4/100]	 Loss: 4.986928450755584	 lr: 0.14941
Step [0/390]	 Loss: 4.996393203735352
Step [50/390]	 Loss: 5.107881546020508
Step [100/390]	 Loss: 4.867859840393066
Step [150/390]	 Loss: 5.041482925415039
Step [200/390]	 Loss: 4.964962005615234
Step [250/390]	 Loss: 5.067287921905518
Step [300/390]	 Loss: 4.994883060455322
Step [350/390]	 Loss: 4.934821128845215


  6%|▌         | 6/100 [07:55<2:04:36, 79.54s/it]

Epoch [5/100]	 Loss: 4.975134795751327	 lr: 0.14908
Step [0/390]	 Loss: 4.921237468719482
Step [50/390]	 Loss: 4.964050769805908
Step [100/390]	 Loss: 4.9307332038879395
Step [150/390]	 Loss: 4.988417148590088
Step [200/390]	 Loss: 4.868237495422363
Step [250/390]	 Loss: 4.990555763244629
Step [300/390]	 Loss: 4.919993877410889
Step [350/390]	 Loss: 4.8377909660339355


  7%|▋         | 7/100 [09:14<2:03:01, 79.37s/it]

Epoch [6/100]	 Loss: 4.9613927547748276	 lr: 0.14867
Step [0/390]	 Loss: 4.971165657043457
Step [50/390]	 Loss: 4.947046279907227
Step [100/390]	 Loss: 4.941020965576172
Step [150/390]	 Loss: 5.011760234832764
Step [200/390]	 Loss: 4.924790382385254
Step [250/390]	 Loss: 4.920597553253174
Step [300/390]	 Loss: 4.933625221252441
Step [350/390]	 Loss: 5.001681327819824


  8%|▊         | 8/100 [10:35<2:02:21, 79.79s/it]

Epoch [7/100]	 Loss: 4.942234587058043	 lr: 0.14819
Step [0/390]	 Loss: 4.914772033691406
Step [50/390]	 Loss: 4.974632740020752
Step [100/390]	 Loss: 4.943131923675537
Step [150/390]	 Loss: 4.917837142944336
Step [200/390]	 Loss: 4.9794392585754395
Step [250/390]	 Loss: 4.912066459655762
Step [300/390]	 Loss: 4.8995137214660645
Step [350/390]	 Loss: 4.9485392570495605


  9%|▉         | 9/100 [11:54<2:00:44, 79.61s/it]

Epoch [8/100]	 Loss: 4.9342997025220825	 lr: 0.14764
Step [0/390]	 Loss: 4.917328357696533
Step [50/390]	 Loss: 4.954358100891113
Step [100/390]	 Loss: 4.850672721862793
Step [150/390]	 Loss: 5.007036209106445
Step [200/390]	 Loss: 4.8655548095703125
Step [250/390]	 Loss: 4.960180759429932
Step [300/390]	 Loss: 4.896512508392334
Step [350/390]	 Loss: 4.910886287689209


 10%|█         | 10/100 [13:14<1:59:35, 79.73s/it]

Epoch [9/100]	 Loss: 4.921820761607243	 lr: 0.14702
Step [0/390]	 Loss: 4.886284828186035
Step [50/390]	 Loss: 4.989652633666992
Step [100/390]	 Loss: 4.889413356781006
Step [150/390]	 Loss: 4.9655303955078125
Step [200/390]	 Loss: 4.937636375427246
Step [250/390]	 Loss: 4.8952226638793945
Step [300/390]	 Loss: 4.968160152435303
Step [350/390]	 Loss: 4.9296722412109375


 11%|█         | 11/100 [14:35<1:58:29, 79.88s/it]

Epoch [10/100]	 Loss: 4.9156657561277735	 lr: 0.14633
Step [0/390]	 Loss: 4.955351829528809
Step [50/390]	 Loss: 4.881499290466309
Step [100/390]	 Loss: 4.971097469329834
Step [150/390]	 Loss: 4.993621826171875
Step [200/390]	 Loss: 4.962740898132324
Step [250/390]	 Loss: 4.855422019958496
Step [300/390]	 Loss: 4.874808311462402
Step [350/390]	 Loss: 4.96935510635376


 12%|█▏        | 12/100 [15:55<1:57:15, 79.95s/it]

Epoch [11/100]	 Loss: 4.914000050226847	 lr: 0.14557
Step [0/390]	 Loss: 4.874512195587158
Step [50/390]	 Loss: 5.0135040283203125
Step [100/390]	 Loss: 5.0058674812316895
Step [150/390]	 Loss: 4.917006969451904
Step [200/390]	 Loss: 4.8702006340026855
Step [250/390]	 Loss: 4.908839702606201
Step [300/390]	 Loss: 4.915426254272461
Step [350/390]	 Loss: 4.799899578094482


 13%|█▎        | 13/100 [17:15<1:56:00, 80.01s/it]

Epoch [12/100]	 Loss: 4.904992788265913	 lr: 0.14473
Step [0/390]	 Loss: 4.931271076202393
Step [50/390]	 Loss: 4.922369003295898
Step [100/390]	 Loss: 4.869243144989014
Step [150/390]	 Loss: 4.9566426277160645
Step [200/390]	 Loss: 4.949307918548584
Step [250/390]	 Loss: 4.843118190765381
Step [300/390]	 Loss: 4.857394218444824
Step [350/390]	 Loss: 4.8670501708984375


 14%|█▍        | 14/100 [18:34<1:54:24, 79.82s/it]

Epoch [13/100]	 Loss: 4.895889417941754	 lr: 0.14383
Step [0/390]	 Loss: 4.963562965393066
Step [50/390]	 Loss: 4.989622592926025
Step [100/390]	 Loss: 4.80924654006958
Step [150/390]	 Loss: 4.946994304656982
Step [200/390]	 Loss: 4.687582015991211
Step [250/390]	 Loss: 4.90040922164917
Step [300/390]	 Loss: 4.962811470031738
Step [350/390]	 Loss: 4.862727642059326


 15%|█▌        | 15/100 [19:54<1:53:02, 79.80s/it]

Epoch [14/100]	 Loss: 4.892977956625131	 lr: 0.14286
Step [0/390]	 Loss: 4.9187774658203125
Step [50/390]	 Loss: 4.9423604011535645
Step [100/390]	 Loss: 4.817951679229736
Step [150/390]	 Loss: 4.881913185119629
Step [200/390]	 Loss: 4.92777681350708
Step [250/390]	 Loss: 4.920962810516357
Step [300/390]	 Loss: 4.9475998878479
Step [350/390]	 Loss: 4.88199520111084


 16%|█▌        | 16/100 [21:14<1:51:36, 79.72s/it]

Epoch [15/100]	 Loss: 4.887099860264705	 lr: 0.14183
Step [0/390]	 Loss: 4.8650970458984375
Step [50/390]	 Loss: 4.940327167510986
Step [100/390]	 Loss: 4.860629558563232
Step [150/390]	 Loss: 4.893383502960205
Step [200/390]	 Loss: 4.816135883331299
Step [250/390]	 Loss: 4.879796981811523
Step [300/390]	 Loss: 4.871893405914307
Step [350/390]	 Loss: 4.873425483703613


 17%|█▋        | 17/100 [22:32<1:49:50, 79.40s/it]

Epoch [16/100]	 Loss: 4.883370156165881	 lr: 0.14072
Step [0/390]	 Loss: 4.80354118347168
Step [50/390]	 Loss: 4.988894462585449
Step [100/390]	 Loss: 4.798681735992432
Step [150/390]	 Loss: 4.845320224761963
Step [200/390]	 Loss: 4.799827575683594
Step [250/390]	 Loss: 4.912813663482666
Step [300/390]	 Loss: 4.827736854553223
Step [350/390]	 Loss: 4.870432376861572


 18%|█▊        | 18/100 [23:51<1:48:27, 79.36s/it]

Epoch [17/100]	 Loss: 4.879869200633123	 lr: 0.13956
Step [0/390]	 Loss: 4.888800144195557
Step [50/390]	 Loss: 4.880702495574951
Step [100/390]	 Loss: 4.821775913238525
Step [150/390]	 Loss: 4.8767547607421875
Step [200/390]	 Loss: 4.844696044921875
Step [250/390]	 Loss: 4.874030113220215
Step [300/390]	 Loss: 4.96914529800415
Step [350/390]	 Loss: 4.785527229309082


 19%|█▉        | 19/100 [25:11<1:47:03, 79.30s/it]

Epoch [18/100]	 Loss: 4.876270925081693	 lr: 0.13832
Step [0/390]	 Loss: 4.925411224365234
Step [50/390]	 Loss: 4.909749984741211
Step [100/390]	 Loss: 4.890962600708008
Step [150/390]	 Loss: 4.9546356201171875
Step [200/390]	 Loss: 4.843392372131348
Step [250/390]	 Loss: 4.811543941497803
Step [300/390]	 Loss: 4.886890888214111
Step [350/390]	 Loss: 4.894262313842773


 20%|██        | 20/100 [26:31<1:46:04, 79.56s/it]

Epoch [19/100]	 Loss: 4.863417739134569	 lr: 0.13703
Step [0/390]	 Loss: 4.81963586807251
Step [50/390]	 Loss: 4.941206932067871
Step [100/390]	 Loss: 4.885904312133789
Step [150/390]	 Loss: 4.888747692108154
Step [200/390]	 Loss: 4.857451438903809
Step [250/390]	 Loss: 4.9148664474487305
Step [300/390]	 Loss: 4.9102983474731445
Step [350/390]	 Loss: 4.880423545837402


 21%|██        | 21/100 [27:51<1:44:49, 79.61s/it]

Epoch [20/100]	 Loss: 4.866018374760945	 lr: 0.13568
Step [0/390]	 Loss: 4.870605945587158
Step [50/390]	 Loss: 4.91276216506958
Step [100/390]	 Loss: 4.829826831817627
Step [150/390]	 Loss: 4.943659782409668
Step [200/390]	 Loss: 4.742379665374756
Step [250/390]	 Loss: 4.828204154968262
Step [300/390]	 Loss: 4.868268013000488
Step [350/390]	 Loss: 4.829588890075684


 22%|██▏       | 22/100 [29:11<1:43:48, 79.85s/it]

Epoch [21/100]	 Loss: 4.863121556013058	 lr: 0.13426
Step [0/390]	 Loss: 4.8392109870910645
Step [50/390]	 Loss: 4.999361038208008
Step [100/390]	 Loss: 4.822728633880615
Step [150/390]	 Loss: 4.89587926864624
Step [200/390]	 Loss: 4.887345790863037
Step [250/390]	 Loss: 4.784997940063477
Step [300/390]	 Loss: 4.872792720794678
Step [350/390]	 Loss: 4.8771257400512695


 23%|██▎       | 23/100 [30:31<1:42:34, 79.93s/it]

Epoch [22/100]	 Loss: 4.858826389068212	 lr: 0.13279
Step [0/390]	 Loss: 4.884120941162109
Step [50/390]	 Loss: 4.964766025543213
Step [100/390]	 Loss: 4.784839153289795
Step [150/390]	 Loss: 4.908825874328613
Step [200/390]	 Loss: 4.866337776184082
Step [250/390]	 Loss: 4.870205879211426
Step [300/390]	 Loss: 4.852588653564453
Step [350/390]	 Loss: 4.8901872634887695


 24%|██▍       | 24/100 [31:51<1:41:25, 80.07s/it]

Epoch [23/100]	 Loss: 4.852334688871335	 lr: 0.13126
Step [0/390]	 Loss: 4.854982852935791
Step [50/390]	 Loss: 4.907201290130615
Step [100/390]	 Loss: 4.81643533706665
Step [150/390]	 Loss: 4.907136917114258
Step [200/390]	 Loss: 4.836904048919678
Step [250/390]	 Loss: 4.878150939941406
Step [300/390]	 Loss: 4.769501686096191
Step [350/390]	 Loss: 4.818353176116943


 25%|██▌       | 25/100 [33:11<1:39:57, 79.97s/it]

Epoch [24/100]	 Loss: 4.852097582205748	 lr: 0.12967
Step [0/390]	 Loss: 4.776683330535889
Step [50/390]	 Loss: 4.894947528839111
Step [100/390]	 Loss: 4.844770431518555
Step [150/390]	 Loss: 4.8864617347717285
Step [200/390]	 Loss: 4.867029666900635
Step [250/390]	 Loss: 4.854527950286865
Step [300/390]	 Loss: 4.9280571937561035
Step [350/390]	 Loss: 4.892309188842773


 26%|██▌       | 26/100 [34:32<1:38:50, 80.14s/it]

Epoch [25/100]	 Loss: 4.849935032771184	 lr: 0.12803
Step [0/390]	 Loss: 4.9037580490112305
Step [50/390]	 Loss: 4.91599178314209
Step [100/390]	 Loss: 4.846837043762207
Step [150/390]	 Loss: 4.792449951171875
Step [200/390]	 Loss: 4.799816608428955
Step [250/390]	 Loss: 4.899838924407959
Step [300/390]	 Loss: 4.868453025817871
Step [350/390]	 Loss: 4.866967678070068


 27%|██▋       | 27/100 [35:52<1:37:29, 80.13s/it]

Epoch [26/100]	 Loss: 4.8496472150851515	 lr: 0.12634
Step [0/390]	 Loss: 4.872795104980469
Step [50/390]	 Loss: 4.875548839569092
Step [100/390]	 Loss: 4.777347564697266
Step [150/390]	 Loss: 4.907937049865723
Step [200/390]	 Loss: 4.822393894195557
Step [250/390]	 Loss: 4.85007381439209
Step [300/390]	 Loss: 4.843885898590088
Step [350/390]	 Loss: 4.777773380279541


 28%|██▊       | 28/100 [37:10<1:35:34, 79.64s/it]

Epoch [27/100]	 Loss: 4.847058354891264	 lr: 0.1246
Step [0/390]	 Loss: 4.880369186401367
Step [50/390]	 Loss: 4.918598651885986
Step [100/390]	 Loss: 4.811778545379639
Step [150/390]	 Loss: 4.81776237487793
Step [200/390]	 Loss: 4.753493309020996
Step [250/390]	 Loss: 4.821118354797363
Step [300/390]	 Loss: 4.819451808929443
Step [350/390]	 Loss: 4.845674514770508


 29%|██▉       | 29/100 [38:30<1:34:21, 79.74s/it]

Epoch [28/100]	 Loss: 4.843096863917816	 lr: 0.12281
Step [0/390]	 Loss: 4.833542823791504
Step [50/390]	 Loss: 4.986169338226318
Step [100/390]	 Loss: 4.7901482582092285
Step [150/390]	 Loss: 4.871285438537598
Step [200/390]	 Loss: 4.921655654907227
Step [250/390]	 Loss: 4.859673976898193
Step [300/390]	 Loss: 4.809645175933838
Step [350/390]	 Loss: 4.954831123352051


 30%|███       | 30/100 [39:50<1:33:07, 79.82s/it]

Epoch [29/100]	 Loss: 4.839883195436918	 lr: 0.12097
Step [0/390]	 Loss: 4.871889591217041
Step [50/390]	 Loss: 4.867625713348389
Step [100/390]	 Loss: 4.828607559204102
Step [150/390]	 Loss: 4.796553611755371
Step [200/390]	 Loss: 4.847245693206787
Step [250/390]	 Loss: 4.8506293296813965
Step [300/390]	 Loss: 4.87748908996582
Step [350/390]	 Loss: 4.804276466369629


 31%|███       | 31/100 [41:12<1:32:29, 80.43s/it]

Epoch [30/100]	 Loss: 4.835371038241264	 lr: 0.11908
Step [0/390]	 Loss: 4.8658623695373535
Step [50/390]	 Loss: 4.847934246063232
Step [100/390]	 Loss: 4.831879615783691
Step [150/390]	 Loss: 4.821375846862793
Step [200/390]	 Loss: 4.8301215171813965
Step [250/390]	 Loss: 4.782459735870361
Step [300/390]	 Loss: 4.7444000244140625
Step [350/390]	 Loss: 4.804419994354248


 32%|███▏      | 32/100 [42:32<1:31:02, 80.33s/it]

Epoch [31/100]	 Loss: 4.833282503714928	 lr: 0.11716
Step [0/390]	 Loss: 4.791171550750732
Step [50/390]	 Loss: 4.819235801696777
Step [100/390]	 Loss: 4.824597358703613
Step [150/390]	 Loss: 4.853602886199951
Step [200/390]	 Loss: 4.839247703552246
Step [250/390]	 Loss: 4.867538928985596
Step [300/390]	 Loss: 4.919516086578369
Step [350/390]	 Loss: 4.783200263977051


 33%|███▎      | 33/100 [43:53<1:29:58, 80.57s/it]

Epoch [32/100]	 Loss: 4.833731152461126	 lr: 0.11519
Step [0/390]	 Loss: 4.678091049194336
Step [50/390]	 Loss: 4.763378143310547
Step [100/390]	 Loss: 4.788016319274902
Step [150/390]	 Loss: 4.851253509521484
Step [200/390]	 Loss: 4.94719123840332
Step [250/390]	 Loss: 4.7831878662109375
Step [300/390]	 Loss: 4.82396125793457
Step [350/390]	 Loss: 4.855517387390137


 34%|███▍      | 34/100 [45:14<1:28:33, 80.51s/it]

Epoch [33/100]	 Loss: 4.829307433886406	 lr: 0.11318
Step [0/390]	 Loss: 4.859655380249023
Step [50/390]	 Loss: 4.935211181640625
Step [100/390]	 Loss: 4.8259711265563965
Step [150/390]	 Loss: 4.885563850402832
Step [200/390]	 Loss: 4.771472454071045
Step [250/390]	 Loss: 4.765807151794434
Step [300/390]	 Loss: 4.769116401672363
Step [350/390]	 Loss: 4.756730556488037


 35%|███▌      | 35/100 [46:34<1:27:06, 80.41s/it]

Epoch [34/100]	 Loss: 4.8289508354969515	 lr: 0.11113
Step [0/390]	 Loss: 4.907063961029053
Step [50/390]	 Loss: 4.890004634857178
Step [100/390]	 Loss: 4.817840099334717
Step [150/390]	 Loss: 4.878933429718018
Step [200/390]	 Loss: 4.803094387054443
Step [250/390]	 Loss: 4.920827388763428
Step [300/390]	 Loss: 4.7519917488098145
Step [350/390]	 Loss: 4.756985187530518


 36%|███▌      | 36/100 [47:54<1:25:40, 80.32s/it]

Epoch [35/100]	 Loss: 4.825876167492988	 lr: 0.10905
Step [0/390]	 Loss: 4.922673225402832
Step [50/390]	 Loss: 4.858007431030273
Step [100/390]	 Loss: 4.727980136871338
Step [150/390]	 Loss: 4.755066394805908
Step [200/390]	 Loss: 4.841015815734863
Step [250/390]	 Loss: 4.821715831756592
Step [300/390]	 Loss: 4.829334259033203
Step [350/390]	 Loss: 4.8307318687438965


 37%|███▋      | 37/100 [49:14<1:24:15, 80.25s/it]

Epoch [36/100]	 Loss: 4.823526721122938	 lr: 0.10693
Step [0/390]	 Loss: 4.806889057159424
Step [50/390]	 Loss: 4.875023365020752
Step [100/390]	 Loss: 4.807910919189453
Step [150/390]	 Loss: 4.77818489074707
Step [200/390]	 Loss: 4.783862113952637
Step [250/390]	 Loss: 4.79443359375
Step [300/390]	 Loss: 4.845841407775879
Step [350/390]	 Loss: 4.794828414916992


 38%|███▊      | 38/100 [50:33<1:22:30, 79.85s/it]

Epoch [37/100]	 Loss: 4.8236920992533365	 lr: 0.10479
Step [0/390]	 Loss: 4.745716094970703
Step [50/390]	 Loss: 4.973682880401611
Step [100/390]	 Loss: 4.770694732666016
Step [150/390]	 Loss: 4.742032051086426
Step [200/390]	 Loss: 4.781137943267822
Step [250/390]	 Loss: 4.8393073081970215
Step [300/390]	 Loss: 4.915913105010986
Step [350/390]	 Loss: 4.832607269287109


 39%|███▉      | 39/100 [51:52<1:20:59, 79.66s/it]

Epoch [38/100]	 Loss: 4.819757586259108	 lr: 0.10261
Step [0/390]	 Loss: 4.759389877319336
Step [50/390]	 Loss: 4.945352077484131
Step [100/390]	 Loss: 4.809445858001709
Step [150/390]	 Loss: 4.801377773284912
Step [200/390]	 Loss: 4.789337158203125
Step [250/390]	 Loss: 4.673856735229492
Step [300/390]	 Loss: 4.8332839012146
Step [350/390]	 Loss: 4.7883405685424805


 40%|████      | 40/100 [53:11<1:19:24, 79.41s/it]

Epoch [39/100]	 Loss: 4.8147507080665	 lr: 0.10041
Step [0/390]	 Loss: 4.8848557472229
Step [50/390]	 Loss: 4.888916015625
Step [100/390]	 Loss: 4.764031887054443
Step [150/390]	 Loss: 4.877956390380859
Step [200/390]	 Loss: 4.809973239898682
Step [250/390]	 Loss: 4.785903453826904
Step [300/390]	 Loss: 4.764732837677002
Step [350/390]	 Loss: 4.903513431549072


 41%|████      | 41/100 [54:31<1:18:18, 79.64s/it]

Epoch [40/100]	 Loss: 4.8156037306174255	 lr: 0.09818
Step [0/390]	 Loss: 4.859389781951904
Step [50/390]	 Loss: 4.78380823135376
Step [100/390]	 Loss: 4.777554512023926
Step [150/390]	 Loss: 4.853231430053711
Step [200/390]	 Loss: 4.7577643394470215
Step [250/390]	 Loss: 4.7825422286987305
Step [300/390]	 Loss: 4.848135471343994
Step [350/390]	 Loss: 4.8371968269348145


 42%|████▏     | 42/100 [55:52<1:17:15, 79.93s/it]

Epoch [41/100]	 Loss: 4.808206940919925	 lr: 0.09592
Step [0/390]	 Loss: 4.870394229888916
Step [50/390]	 Loss: 4.746194839477539
Step [100/390]	 Loss: 4.8231635093688965
Step [150/390]	 Loss: 4.851089954376221
Step [200/390]	 Loss: 4.8589186668396
Step [250/390]	 Loss: 4.763432025909424
Step [300/390]	 Loss: 4.772739887237549
Step [350/390]	 Loss: 4.789035320281982


 43%|████▎     | 43/100 [57:12<1:16:00, 80.00s/it]

Epoch [42/100]	 Loss: 4.814425617609269	 lr: 0.09365
Step [0/390]	 Loss: 4.821849822998047
Step [50/390]	 Loss: 4.818177223205566
Step [100/390]	 Loss: 4.734741687774658
Step [150/390]	 Loss: 4.837981700897217
Step [200/390]	 Loss: 4.76930046081543
Step [250/390]	 Loss: 4.761483669281006
Step [300/390]	 Loss: 4.879786968231201
Step [350/390]	 Loss: 4.742130756378174


 44%|████▍     | 44/100 [58:33<1:15:03, 80.42s/it]

Epoch [43/100]	 Loss: 4.805530512638581	 lr: 0.09136
Step [0/390]	 Loss: 4.842284202575684
Step [50/390]	 Loss: 4.889620780944824
Step [100/390]	 Loss: 4.806737899780273
Step [150/390]	 Loss: 4.826907157897949
Step [200/390]	 Loss: 4.785515308380127
Step [250/390]	 Loss: 4.807455062866211
Step [300/390]	 Loss: 4.88357400894165
Step [350/390]	 Loss: 4.786941051483154


 45%|████▌     | 45/100 [59:54<1:13:39, 80.36s/it]

Epoch [44/100]	 Loss: 4.811325289652898	 lr: 0.08905
Step [0/390]	 Loss: 4.791171073913574
Step [50/390]	 Loss: 4.809182643890381
Step [100/390]	 Loss: 4.8062238693237305
Step [150/390]	 Loss: 4.8713059425354
Step [200/390]	 Loss: 4.768840789794922
Step [250/390]	 Loss: 4.712364673614502
Step [300/390]	 Loss: 4.790351390838623
Step [350/390]	 Loss: 4.74245548248291


 46%|████▌     | 46/100 [1:01:14<1:12:15, 80.29s/it]

Epoch [45/100]	 Loss: 4.803899668424558	 lr: 0.08673
Step [0/390]	 Loss: 4.8614959716796875
Step [50/390]	 Loss: 4.791454792022705
Step [100/390]	 Loss: 4.717182159423828
Step [150/390]	 Loss: 4.792861461639404
Step [200/390]	 Loss: 4.83125114440918
Step [250/390]	 Loss: 4.742708683013916
Step [300/390]	 Loss: 4.916377544403076
Step [350/390]	 Loss: 4.687212944030762


 47%|████▋     | 47/100 [1:02:34<1:10:52, 80.23s/it]

Epoch [46/100]	 Loss: 4.810975988094623	 lr: 0.0844
Step [0/390]	 Loss: 4.7958292961120605
Step [50/390]	 Loss: 4.758151531219482
Step [100/390]	 Loss: 4.810645580291748
Step [150/390]	 Loss: 4.844700813293457
Step [200/390]	 Loss: 4.86786413192749
Step [250/390]	 Loss: 4.833565711975098
Step [300/390]	 Loss: 4.868598937988281
Step [350/390]	 Loss: 4.80319881439209


 48%|████▊     | 48/100 [1:03:54<1:09:37, 80.34s/it]

Epoch [47/100]	 Loss: 4.806131663689246	 lr: 0.08206
Step [0/390]	 Loss: 4.779226779937744
Step [50/390]	 Loss: 4.774426460266113
Step [100/390]	 Loss: 4.7889604568481445
Step [150/390]	 Loss: 4.859241008758545
Step [200/390]	 Loss: 4.725976943969727
Step [250/390]	 Loss: 4.757786273956299
Step [300/390]	 Loss: 4.781139373779297
Step [350/390]	 Loss: 4.838066101074219


 49%|████▉     | 49/100 [1:05:14<1:08:11, 80.23s/it]

Epoch [48/100]	 Loss: 4.8022793916555555	 lr: 0.07971
Step [0/390]	 Loss: 4.830896377563477
Step [50/390]	 Loss: 4.779727935791016
Step [100/390]	 Loss: 4.7408552169799805
Step [150/390]	 Loss: 4.6965651512146
Step [200/390]	 Loss: 4.84443473815918
Step [250/390]	 Loss: 4.80403995513916
Step [300/390]	 Loss: 4.7865753173828125
Step [350/390]	 Loss: 4.771652698516846


 50%|█████     | 50/100 [1:06:34<1:06:35, 79.91s/it]

Epoch [49/100]	 Loss: 4.799119193737323	 lr: 0.07736
Step [0/390]	 Loss: 4.803257465362549
Step [50/390]	 Loss: 4.864778518676758
Step [100/390]	 Loss: 4.754868507385254
Step [150/390]	 Loss: 4.793455600738525
Step [200/390]	 Loss: 4.772823333740234
Step [250/390]	 Loss: 4.819356441497803
Step [300/390]	 Loss: 4.808899402618408
Step [350/390]	 Loss: 4.695580959320068


 51%|█████     | 51/100 [1:07:54<1:05:21, 80.04s/it]

Epoch [50/100]	 Loss: 4.799575837453206	 lr: 0.075
Step [0/390]	 Loss: 4.888085842132568
Step [50/390]	 Loss: 4.864787578582764
Step [100/390]	 Loss: 4.749118804931641
Step [150/390]	 Loss: 4.750174045562744
Step [200/390]	 Loss: 4.807921409606934
Step [250/390]	 Loss: 4.920691013336182
Step [300/390]	 Loss: 4.820240020751953
Step [350/390]	 Loss: 4.801974773406982


 52%|█████▏    | 52/100 [1:09:14<1:04:05, 80.12s/it]

Epoch [51/100]	 Loss: 4.796852934666169	 lr: 0.07264
Step [0/390]	 Loss: 4.7373552322387695
Step [50/390]	 Loss: 4.862214088439941
Step [100/390]	 Loss: 4.716655731201172
Step [150/390]	 Loss: 4.815762519836426
Step [200/390]	 Loss: 4.8118896484375
Step [250/390]	 Loss: 4.767536163330078
Step [300/390]	 Loss: 4.695971488952637
Step [350/390]	 Loss: 4.760157585144043


 53%|█████▎    | 53/100 [1:10:35<1:02:51, 80.25s/it]

Epoch [52/100]	 Loss: 4.797621255043225	 lr: 0.07029
Step [0/390]	 Loss: 4.706611156463623
Step [50/390]	 Loss: 4.842843055725098
Step [100/390]	 Loss: 4.845355987548828
Step [150/390]	 Loss: 4.7927985191345215
Step [200/390]	 Loss: 4.805053234100342
Step [250/390]	 Loss: 4.776185989379883
Step [300/390]	 Loss: 4.795899391174316
Step [350/390]	 Loss: 4.677387237548828


 54%|█████▍    | 54/100 [1:11:55<1:01:24, 80.10s/it]

Epoch [53/100]	 Loss: 4.791744075677333	 lr: 0.06794
Step [0/390]	 Loss: 4.902066707611084
Step [50/390]	 Loss: 4.792440891265869
Step [100/390]	 Loss: 4.753053188323975
Step [150/390]	 Loss: 4.829868793487549
Step [200/390]	 Loss: 4.8318915367126465
Step [250/390]	 Loss: 4.832878112792969
Step [300/390]	 Loss: 4.7751874923706055
Step [350/390]	 Loss: 4.767620086669922


 55%|█████▌    | 55/100 [1:13:16<1:00:21, 80.48s/it]

Epoch [54/100]	 Loss: 4.792961838306525	 lr: 0.0656
Step [0/390]	 Loss: 4.8439764976501465
Step [50/390]	 Loss: 4.867824554443359
Step [100/390]	 Loss: 4.768752098083496
Step [150/390]	 Loss: 4.856253623962402
Step [200/390]	 Loss: 4.760541915893555
Step [250/390]	 Loss: 4.800450325012207
Step [300/390]	 Loss: 4.850249290466309
Step [350/390]	 Loss: 4.77191162109375


 56%|█████▌    | 56/100 [1:14:34<58:35, 79.90s/it]  

Epoch [55/100]	 Loss: 4.794119517008464	 lr: 0.06327
Step [0/390]	 Loss: 4.784792900085449
Step [50/390]	 Loss: 4.841598987579346
Step [100/390]	 Loss: 4.773364543914795
Step [150/390]	 Loss: 4.8124589920043945
Step [200/390]	 Loss: 4.916037559509277
Step [250/390]	 Loss: 4.802045822143555
Step [300/390]	 Loss: 4.853157043457031
Step [350/390]	 Loss: 4.7751946449279785


 57%|█████▋    | 57/100 [1:15:53<57:00, 79.55s/it]

Epoch [56/100]	 Loss: 4.7914601032550515	 lr: 0.06095
Step [0/390]	 Loss: 4.688639163970947
Step [50/390]	 Loss: 4.787271499633789
Step [100/390]	 Loss: 4.764842510223389
Step [150/390]	 Loss: 4.855223655700684
Step [200/390]	 Loss: 4.816818714141846
Step [250/390]	 Loss: 4.775785446166992
Step [300/390]	 Loss: 4.814448833465576
Step [350/390]	 Loss: 4.77771520614624


 58%|█████▊    | 58/100 [1:17:11<55:23, 79.13s/it]

Epoch [57/100]	 Loss: 4.788465245564779	 lr: 0.05864
Step [0/390]	 Loss: 4.857124328613281
Step [50/390]	 Loss: 4.749012470245361
Step [100/390]	 Loss: 4.687756538391113
Step [150/390]	 Loss: 4.773930549621582
Step [200/390]	 Loss: 4.859220027923584
Step [250/390]	 Loss: 4.758362770080566
Step [300/390]	 Loss: 4.817574501037598
Step [350/390]	 Loss: 4.715416431427002


 59%|█████▉    | 59/100 [1:18:30<54:00, 79.03s/it]

Epoch [58/100]	 Loss: 4.782261742078341	 lr: 0.05635
Step [0/390]	 Loss: 4.799408435821533
Step [50/390]	 Loss: 4.843457221984863
Step [100/390]	 Loss: 4.690548896789551
Step [150/390]	 Loss: 4.82515287399292
Step [200/390]	 Loss: 4.6935319900512695
Step [250/390]	 Loss: 4.810828685760498
Step [300/390]	 Loss: 4.791777610778809
Step [350/390]	 Loss: 4.636219501495361


 60%|██████    | 60/100 [1:19:48<52:28, 78.72s/it]

Epoch [59/100]	 Loss: 4.787316663448627	 lr: 0.05408
Step [0/390]	 Loss: 4.820061206817627
Step [50/390]	 Loss: 4.844487190246582
Step [100/390]	 Loss: 4.789336204528809
Step [150/390]	 Loss: 4.768209934234619
Step [200/390]	 Loss: 4.767251014709473
Step [250/390]	 Loss: 4.772615432739258
Step [300/390]	 Loss: 4.78759241104126
Step [350/390]	 Loss: 4.70219087600708


 61%|██████    | 61/100 [1:21:06<50:57, 78.39s/it]

Epoch [60/100]	 Loss: 4.789539615924542	 lr: 0.05182
Step [0/390]	 Loss: 4.695736408233643
Step [50/390]	 Loss: 4.874308109283447
Step [100/390]	 Loss: 4.792397975921631
Step [150/390]	 Loss: 4.807600021362305
Step [200/390]	 Loss: 4.744405269622803
Step [250/390]	 Loss: 4.866097450256348
Step [300/390]	 Loss: 4.819290637969971
Step [350/390]	 Loss: 4.7444610595703125


 62%|██████▏   | 62/100 [1:22:23<49:28, 78.13s/it]

Epoch [61/100]	 Loss: 4.7866011692927435	 lr: 0.04959
Step [0/390]	 Loss: 4.888483047485352
Step [50/390]	 Loss: 4.761117458343506
Step [100/390]	 Loss: 4.656728267669678
Step [150/390]	 Loss: 4.768110752105713
Step [200/390]	 Loss: 4.784303188323975
Step [250/390]	 Loss: 4.782151222229004
Step [300/390]	 Loss: 4.734480381011963
Step [350/390]	 Loss: 4.773523330688477


 63%|██████▎   | 63/100 [1:23:40<47:55, 77.71s/it]

Epoch [62/100]	 Loss: 4.7811293394137655	 lr: 0.04739
Step [0/390]	 Loss: 4.879875659942627
Step [50/390]	 Loss: 4.792293071746826
Step [100/390]	 Loss: 4.823534965515137
Step [150/390]	 Loss: 4.785086154937744
Step [200/390]	 Loss: 4.795567035675049
Step [250/390]	 Loss: 4.789759159088135
Step [300/390]	 Loss: 4.852727890014648
Step [350/390]	 Loss: 4.7839274406433105


 64%|██████▍   | 64/100 [1:24:58<46:36, 77.67s/it]

Epoch [63/100]	 Loss: 4.785118323106032	 lr: 0.04521
Step [0/390]	 Loss: 4.690757751464844
Step [50/390]	 Loss: 4.789544105529785
Step [100/390]	 Loss: 4.756833553314209
Step [150/390]	 Loss: 4.791020393371582
Step [200/390]	 Loss: 4.781803131103516
Step [250/390]	 Loss: 4.781594753265381
Step [300/390]	 Loss: 4.853687763214111
Step [350/390]	 Loss: 4.762060165405273


 65%|██████▌   | 65/100 [1:26:14<45:08, 77.40s/it]

Epoch [64/100]	 Loss: 4.782819983898065	 lr: 0.04307
Step [0/390]	 Loss: 4.7315239906311035
Step [50/390]	 Loss: 4.835689067840576
Step [100/390]	 Loss: 4.803620338439941
Step [150/390]	 Loss: 4.823591232299805
Step [200/390]	 Loss: 4.817276477813721
Step [250/390]	 Loss: 4.696132659912109
Step [300/390]	 Loss: 4.820062160491943
Step [350/390]	 Loss: 4.769502639770508


 66%|██████▌   | 66/100 [1:27:32<43:52, 77.42s/it]

Epoch [65/100]	 Loss: 4.77991884182661	 lr: 0.04095
Step [0/390]	 Loss: 4.792171001434326
Step [50/390]	 Loss: 4.769537448883057
Step [100/390]	 Loss: 4.72451114654541
Step [150/390]	 Loss: 4.784765243530273
Step [200/390]	 Loss: 4.739223480224609
Step [250/390]	 Loss: 4.817729473114014
Step [300/390]	 Loss: 4.799874305725098
Step [350/390]	 Loss: 4.755814075469971


 67%|██████▋   | 67/100 [1:28:48<42:26, 77.18s/it]

Epoch [66/100]	 Loss: 4.775968034450824	 lr: 0.03887
Step [0/390]	 Loss: 4.840329170227051
Step [50/390]	 Loss: 4.812402248382568
Step [100/390]	 Loss: 4.70973014831543
Step [150/390]	 Loss: 4.8613057136535645
Step [200/390]	 Loss: 4.750174522399902
Step [250/390]	 Loss: 4.717702865600586
Step [300/390]	 Loss: 4.800720691680908
Step [350/390]	 Loss: 4.712151050567627


 68%|██████▊   | 68/100 [1:30:05<41:02, 76.97s/it]

Epoch [67/100]	 Loss: 4.776996913323035	 lr: 0.03682
Step [0/390]	 Loss: 4.823721408843994
Step [50/390]	 Loss: 4.848724842071533
Step [100/390]	 Loss: 4.689579010009766
Step [150/390]	 Loss: 4.7900543212890625
Step [200/390]	 Loss: 4.755160331726074
Step [250/390]	 Loss: 4.7760329246521
Step [300/390]	 Loss: 4.77465295791626
Step [350/390]	 Loss: 4.798854351043701


 69%|██████▉   | 69/100 [1:31:22<39:50, 77.12s/it]

Epoch [68/100]	 Loss: 4.777940995876605	 lr: 0.03481
Step [0/390]	 Loss: 4.688852787017822
Step [50/390]	 Loss: 4.850841522216797
Step [100/390]	 Loss: 4.738046646118164
Step [150/390]	 Loss: 4.778196811676025
Step [200/390]	 Loss: 4.800258636474609
Step [250/390]	 Loss: 4.836348533630371
Step [300/390]	 Loss: 4.8026862144470215
Step [350/390]	 Loss: 4.792506694793701


 70%|███████   | 70/100 [1:32:40<38:36, 77.22s/it]

Epoch [69/100]	 Loss: 4.776214631398519	 lr: 0.03284
Step [0/390]	 Loss: 4.798492908477783
Step [50/390]	 Loss: 4.797236919403076
Step [100/390]	 Loss: 4.760593891143799
Step [150/390]	 Loss: 4.842412948608398
Step [200/390]	 Loss: 4.736880779266357
Step [250/390]	 Loss: 4.777628421783447
Step [300/390]	 Loss: 4.838723659515381
Step [350/390]	 Loss: 4.703667640686035


 71%|███████   | 71/100 [1:33:57<37:22, 77.32s/it]

Epoch [70/100]	 Loss: 4.7726188781933905	 lr: 0.03092
Step [0/390]	 Loss: 4.81231689453125
Step [50/390]	 Loss: 4.808975696563721
Step [100/390]	 Loss: 4.712919235229492
Step [150/390]	 Loss: 4.786257266998291
Step [200/390]	 Loss: 4.751509189605713
Step [250/390]	 Loss: 4.8061957359313965
Step [300/390]	 Loss: 4.780759811401367
Step [350/390]	 Loss: 4.800724506378174


 72%|███████▏  | 72/100 [1:35:14<36:01, 77.20s/it]

Epoch [71/100]	 Loss: 4.769652420435197	 lr: 0.02903
Step [0/390]	 Loss: 4.741531848907471
Step [50/390]	 Loss: 4.835453033447266
Step [100/390]	 Loss: 4.750783443450928
Step [150/390]	 Loss: 4.796507835388184
Step [200/390]	 Loss: 4.784970760345459
Step [250/390]	 Loss: 4.7690253257751465
Step [300/390]	 Loss: 4.73044490814209
Step [350/390]	 Loss: 4.7725911140441895


 73%|███████▎  | 73/100 [1:36:31<34:40, 77.06s/it]

Epoch [72/100]	 Loss: 4.776446722715329	 lr: 0.02719
Step [0/390]	 Loss: 4.715878486633301
Step [50/390]	 Loss: 4.8376851081848145
Step [100/390]	 Loss: 4.721851348876953
Step [150/390]	 Loss: 4.740993499755859
Step [200/390]	 Loss: 4.76612663269043
Step [250/390]	 Loss: 4.810049533843994
Step [300/390]	 Loss: 4.787276744842529
Step [350/390]	 Loss: 4.710728168487549


 74%|███████▍  | 74/100 [1:37:47<33:17, 76.84s/it]

Epoch [73/100]	 Loss: 4.7743708219283665	 lr: 0.0254
Step [0/390]	 Loss: 4.852401256561279
Step [50/390]	 Loss: 4.853151321411133
Step [100/390]	 Loss: 4.669278144836426
Step [150/390]	 Loss: 4.79448938369751
Step [200/390]	 Loss: 4.83576774597168
Step [250/390]	 Loss: 4.7230143547058105
Step [300/390]	 Loss: 4.8405585289001465
Step [350/390]	 Loss: 4.717532157897949


 75%|███████▌  | 75/100 [1:39:04<31:59, 76.77s/it]

Epoch [74/100]	 Loss: 4.7740401598123405	 lr: 0.02366
Step [0/390]	 Loss: 4.81102991104126
Step [50/390]	 Loss: 4.827579975128174
Step [100/390]	 Loss: 4.748167514801025
Step [150/390]	 Loss: 4.792995929718018
Step [200/390]	 Loss: 4.8052239418029785
Step [250/390]	 Loss: 4.796210289001465
Step [300/390]	 Loss: 4.762868881225586
Step [350/390]	 Loss: 4.721431732177734


 76%|███████▌  | 76/100 [1:40:22<30:49, 77.07s/it]

Epoch [75/100]	 Loss: 4.772078764744294	 lr: 0.02197
Step [0/390]	 Loss: 4.843206405639648
Step [50/390]	 Loss: 4.835709571838379
Step [100/390]	 Loss: 4.7363080978393555
Step [150/390]	 Loss: 4.814302444458008
Step [200/390]	 Loss: 4.730990409851074
Step [250/390]	 Loss: 4.702574253082275
Step [300/390]	 Loss: 4.763994216918945
Step [350/390]	 Loss: 4.701466083526611


 77%|███████▋  | 77/100 [1:41:39<29:30, 76.99s/it]

Epoch [76/100]	 Loss: 4.773648474766658	 lr: 0.02033
Step [0/390]	 Loss: 4.696099281311035
Step [50/390]	 Loss: 4.835065841674805
Step [100/390]	 Loss: 4.794012069702148
Step [150/390]	 Loss: 4.690007209777832
Step [200/390]	 Loss: 4.768135070800781
Step [250/390]	 Loss: 4.779165267944336
Step [300/390]	 Loss: 4.741305351257324
Step [350/390]	 Loss: 4.791328430175781


 78%|███████▊  | 78/100 [1:42:55<28:13, 76.96s/it]

Epoch [77/100]	 Loss: 4.770514406302037	 lr: 0.01874
Step [0/390]	 Loss: 4.875699520111084
Step [50/390]	 Loss: 4.692551612854004
Step [100/390]	 Loss: 4.745932102203369
Step [150/390]	 Loss: 4.719049453735352
Step [200/390]	 Loss: 4.792635917663574
Step [250/390]	 Loss: 4.772089004516602
Step [300/390]	 Loss: 4.7836456298828125
Step [350/390]	 Loss: 4.7519073486328125


 79%|███████▉  | 79/100 [1:44:12<26:55, 76.94s/it]

Epoch [78/100]	 Loss: 4.771986979704637	 lr: 0.01721
Step [0/390]	 Loss: 4.7693705558776855
Step [50/390]	 Loss: 4.821074962615967
Step [100/390]	 Loss: 4.7645087242126465
Step [150/390]	 Loss: 4.808010101318359
Step [200/390]	 Loss: 4.7606072425842285
Step [250/390]	 Loss: 4.747913360595703
Step [300/390]	 Loss: 4.736688137054443
Step [350/390]	 Loss: 4.813571929931641


 80%|████████  | 80/100 [1:45:29<25:40, 77.00s/it]

Epoch [79/100]	 Loss: 4.76879880856245	 lr: 0.01574
Step [0/390]	 Loss: 4.8394670486450195
Step [50/390]	 Loss: 4.815471649169922
Step [100/390]	 Loss: 4.70981502532959
Step [150/390]	 Loss: 4.765129566192627
Step [200/390]	 Loss: 4.8077497482299805
Step [250/390]	 Loss: 4.8776092529296875
Step [300/390]	 Loss: 4.742385387420654
Step [350/390]	 Loss: 4.784123420715332


 81%|████████  | 81/100 [1:46:47<24:25, 77.11s/it]

Epoch [80/100]	 Loss: 4.766948633927566	 lr: 0.01432
Step [0/390]	 Loss: 4.735583305358887
Step [50/390]	 Loss: 4.750688076019287
Step [100/390]	 Loss: 4.779519081115723
Step [150/390]	 Loss: 4.8151774406433105
Step [200/390]	 Loss: 4.743139743804932
Step [250/390]	 Loss: 4.743422985076904
Step [300/390]	 Loss: 4.839510917663574
Step [350/390]	 Loss: 4.757187843322754


 82%|████████▏ | 82/100 [1:48:04<23:07, 77.09s/it]

Epoch [81/100]	 Loss: 4.773110635464008	 lr: 0.01297
Step [0/390]	 Loss: 4.711407661437988
Step [50/390]	 Loss: 4.841849327087402
Step [100/390]	 Loss: 4.722117900848389
Step [150/390]	 Loss: 4.734948635101318
Step [200/390]	 Loss: 4.826355934143066
Step [250/390]	 Loss: 4.727850437164307
Step [300/390]	 Loss: 4.755346775054932
Step [350/390]	 Loss: 4.767836093902588


 83%|████████▎ | 83/100 [1:49:22<21:56, 77.43s/it]

Epoch [82/100]	 Loss: 4.7665194621452915	 lr: 0.01168
Step [0/390]	 Loss: 4.729568958282471
Step [50/390]	 Loss: 4.784572601318359
Step [100/390]	 Loss: 4.835345268249512
Step [150/390]	 Loss: 4.747394561767578
Step [200/390]	 Loss: 4.747801780700684
Step [250/390]	 Loss: 4.764652729034424
Step [300/390]	 Loss: 4.672544956207275
Step [350/390]	 Loss: 4.771711349487305


 84%|████████▍ | 84/100 [1:50:39<20:36, 77.30s/it]

Epoch [83/100]	 Loss: 4.768558439841637	 lr: 0.01044
Step [0/390]	 Loss: 4.822289943695068
Step [50/390]	 Loss: 4.806178092956543
Step [100/390]	 Loss: 4.740512371063232
Step [150/390]	 Loss: 4.80312442779541
Step [200/390]	 Loss: 4.755184173583984
Step [250/390]	 Loss: 4.7500786781311035
Step [300/390]	 Loss: 4.786617279052734
Step [350/390]	 Loss: 4.7089314460754395


 85%|████████▌ | 85/100 [1:51:57<19:22, 77.48s/it]

Epoch [84/100]	 Loss: 4.765638762253982	 lr: 0.00928
Step [0/390]	 Loss: 4.823784828186035
Step [50/390]	 Loss: 4.839721202850342
Step [100/390]	 Loss: 4.700499534606934
Step [150/390]	 Loss: 4.815722942352295
Step [200/390]	 Loss: 4.786620140075684
Step [250/390]	 Loss: 4.790695667266846
Step [300/390]	 Loss: 4.764983654022217
Step [350/390]	 Loss: 4.734499931335449


 86%|████████▌ | 86/100 [1:53:15<18:06, 77.61s/it]

Epoch [85/100]	 Loss: 4.770969092540252	 lr: 0.00817
Step [0/390]	 Loss: 4.698449611663818
Step [50/390]	 Loss: 4.758927345275879
Step [100/390]	 Loss: 4.789244651794434
Step [150/390]	 Loss: 4.751227855682373
Step [200/390]	 Loss: 4.712223529815674
Step [250/390]	 Loss: 4.780503273010254
Step [300/390]	 Loss: 4.818065166473389
Step [350/390]	 Loss: 4.761908054351807


 87%|████████▋ | 87/100 [1:54:33<16:50, 77.71s/it]

Epoch [86/100]	 Loss: 4.771022196305104	 lr: 0.00714
Step [0/390]	 Loss: 4.721230983734131
Step [50/390]	 Loss: 4.754908561706543
Step [100/390]	 Loss: 4.724594593048096
Step [150/390]	 Loss: 4.775038242340088
Step [200/390]	 Loss: 4.7032670974731445
Step [250/390]	 Loss: 4.693313121795654
Step [300/390]	 Loss: 4.761257171630859
Step [350/390]	 Loss: 4.698040962219238


 88%|████████▊ | 88/100 [1:55:49<15:27, 77.25s/it]

Epoch [87/100]	 Loss: 4.764029164192004	 lr: 0.00617
Step [0/390]	 Loss: 4.831154823303223
Step [50/390]	 Loss: 4.819319725036621
Step [100/390]	 Loss: 4.761665344238281
Step [150/390]	 Loss: 4.786736488342285
Step [200/390]	 Loss: 4.790834426879883
Step [250/390]	 Loss: 4.685606956481934
Step [300/390]	 Loss: 4.669759750366211
Step [350/390]	 Loss: 4.733574390411377


 89%|████████▉ | 89/100 [1:57:06<14:10, 77.29s/it]

Epoch [88/100]	 Loss: 4.766204296014248	 lr: 0.00527
Step [0/390]	 Loss: 4.771656036376953
Step [50/390]	 Loss: 4.8927435874938965
Step [100/390]	 Loss: 4.701238632202148
Step [150/390]	 Loss: 4.8319244384765625
Step [200/390]	 Loss: 4.7694268226623535
Step [250/390]	 Loss: 4.786733627319336
Step [300/390]	 Loss: 4.700782775878906
Step [350/390]	 Loss: 4.729974746704102


 90%|█████████ | 90/100 [1:58:23<12:50, 77.08s/it]

Epoch [89/100]	 Loss: 4.765175380462256	 lr: 0.00443
Step [0/390]	 Loss: 4.781793117523193
Step [50/390]	 Loss: 4.828197002410889
Step [100/390]	 Loss: 4.75984525680542
Step [150/390]	 Loss: 4.780405521392822
Step [200/390]	 Loss: 4.768982410430908
Step [250/390]	 Loss: 4.769301891326904
Step [300/390]	 Loss: 4.717963218688965
Step [350/390]	 Loss: 4.812739849090576


 91%|█████████ | 91/100 [1:59:41<11:35, 77.31s/it]

Epoch [90/100]	 Loss: 4.7655748122777695	 lr: 0.00367
Step [0/390]	 Loss: 4.755336284637451
Step [50/390]	 Loss: 4.719883918762207
Step [100/390]	 Loss: 4.687856197357178
Step [150/390]	 Loss: 4.777523040771484
Step [200/390]	 Loss: 4.780395984649658
Step [250/390]	 Loss: 4.720319747924805
Step [300/390]	 Loss: 4.6543498039245605
Step [350/390]	 Loss: 4.796549320220947


 92%|█████████▏| 92/100 [2:00:59<10:19, 77.47s/it]

Epoch [91/100]	 Loss: 4.767754020446386	 lr: 0.00298
Step [0/390]	 Loss: 4.803224086761475
Step [50/390]	 Loss: 4.781676292419434
Step [100/390]	 Loss: 4.70796537399292
Step [150/390]	 Loss: 4.751513957977295
Step [200/390]	 Loss: 4.7735209465026855
Step [250/390]	 Loss: 4.763165473937988
Step [300/390]	 Loss: 4.757328987121582
Step [350/390]	 Loss: 4.702865123748779


 93%|█████████▎| 93/100 [2:02:16<09:01, 77.32s/it]

Epoch [92/100]	 Loss: 4.763536107234466	 lr: 0.00236
Step [0/390]	 Loss: 4.699700832366943
Step [50/390]	 Loss: 4.757369041442871
Step [100/390]	 Loss: 4.716769218444824
Step [150/390]	 Loss: 4.726726055145264
Step [200/390]	 Loss: 4.7718048095703125
Step [250/390]	 Loss: 4.8201775550842285
Step [300/390]	 Loss: 4.8058013916015625
Step [350/390]	 Loss: 4.761889457702637


 94%|█████████▍| 94/100 [2:03:33<07:43, 77.21s/it]

Epoch [93/100]	 Loss: 4.768422640286959	 lr: 0.00181
Step [0/390]	 Loss: 4.813785552978516
Step [50/390]	 Loss: 4.794251441955566
Step [100/390]	 Loss: 4.760996341705322
Step [150/390]	 Loss: 4.707696437835693
Step [200/390]	 Loss: 4.752285480499268
Step [250/390]	 Loss: 4.820991516113281
Step [300/390]	 Loss: 4.734037399291992
Step [350/390]	 Loss: 4.768703460693359


 95%|█████████▌| 95/100 [2:04:50<06:26, 77.33s/it]

Epoch [94/100]	 Loss: 4.760280918463683	 lr: 0.00133
Step [0/390]	 Loss: 4.750439167022705
Step [50/390]	 Loss: 4.780170440673828
Step [100/390]	 Loss: 4.740994930267334
Step [150/390]	 Loss: 4.8375115394592285
Step [200/390]	 Loss: 4.804592132568359
Step [250/390]	 Loss: 4.797694683074951
Step [300/390]	 Loss: 4.727060317993164
Step [350/390]	 Loss: 4.752927303314209


 96%|█████████▌| 96/100 [2:06:08<05:09, 77.36s/it]

Epoch [95/100]	 Loss: 4.764671083597037	 lr: 0.00092
Step [0/390]	 Loss: 4.772167682647705
Step [50/390]	 Loss: 4.737239837646484
Step [100/390]	 Loss: 4.656899452209473
Step [150/390]	 Loss: 4.755734443664551
Step [200/390]	 Loss: 4.846535682678223
Step [250/390]	 Loss: 4.747959136962891
Step [300/390]	 Loss: 4.636353492736816
Step [350/390]	 Loss: 4.722289562225342


 97%|█████████▋| 97/100 [2:07:25<03:51, 77.32s/it]

Epoch [96/100]	 Loss: 4.764199520991399	 lr: 0.00059
Step [0/390]	 Loss: 4.858349800109863
Step [50/390]	 Loss: 4.832698345184326
Step [100/390]	 Loss: 4.842981338500977
Step [150/390]	 Loss: 4.78468132019043
Step [200/390]	 Loss: 4.869058609008789
Step [250/390]	 Loss: 4.802465915679932
Step [300/390]	 Loss: 4.757975101470947
Step [350/390]	 Loss: 4.763730525970459


 98%|█████████▊| 98/100 [2:08:43<02:34, 77.48s/it]

Epoch [97/100]	 Loss: 4.76603589791518	 lr: 0.00033
Step [0/390]	 Loss: 4.845604419708252
Step [50/390]	 Loss: 4.773802757263184
Step [100/390]	 Loss: 4.795609951019287
Step [150/390]	 Loss: 4.7747931480407715
Step [200/390]	 Loss: 4.7818145751953125
Step [250/390]	 Loss: 4.7421746253967285
Step [300/390]	 Loss: 4.835600852966309
Step [350/390]	 Loss: 4.788452625274658


 99%|█████████▉| 99/100 [2:10:01<01:17, 77.82s/it]

Epoch [98/100]	 Loss: 4.769894914138011	 lr: 0.00015
Step [0/390]	 Loss: 4.8128790855407715
Step [50/390]	 Loss: 4.794370174407959
Step [100/390]	 Loss: 4.7896552085876465
Step [150/390]	 Loss: 4.770269870758057
Step [200/390]	 Loss: 4.731128692626953
Step [250/390]	 Loss: 4.7817206382751465
Step [300/390]	 Loss: 4.851526260375977
Step [350/390]	 Loss: 4.667038917541504


100%|██████████| 100/100 [2:11:19<00:00, 78.79s/it]

Epoch [99/100]	 Loss: 4.764617267021766	 lr: 4e-05





In [23]:
from simclr.modules import LogisticRegression

def train(loader, model, criterion, optimizer):
    loss_epoch = 0
    accuracy_epoch = 0
    for step, (x, y) in enumerate(loader):
        optimizer.zero_grad()

        x = x.to(DEVICE)
        y = y.to(DEVICE)

        output = model(x)
        loss = criterion(output, y)

        predicted = output.argmax(1)
        acc = (predicted == y).sum().item() / y.size(0)
        accuracy_epoch += acc

        loss.backward()
        optimizer.step()

        loss_epoch += loss.item()

    return loss_epoch, accuracy_epoch

In [24]:
def test(loader, model, criterion, optimizer):
    loss_epoch = 0
    accuracy_epoch = 0
    model.eval()
    for step, (x, y) in enumerate(loader):
        model.zero_grad()

        x = x.to(DEVICE)
        y = y.to(DEVICE)

        output = model(x)
        loss = criterion(output, y)

        predicted = output.argmax(1)
        acc = (predicted == y).sum().item() / y.size(0)
        accuracy_epoch += acc

        loss_epoch += loss.item()

    return loss_epoch, accuracy_epoch

In [25]:
epoch_num = 100
logistic_epochs = 800

In [26]:
train_transform = tim_transform
test_transform = train_transform

train_dataset = torchvision.datasets.CIFAR10(root = '/content/sample_data', transform = train_transform, download = True)
test_dataset = torchvision.datasets.CIFAR10(root = '/content/sample_data', transform = test_transform)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size = batch_size,
    shuffle = False,
    drop_last=True,
    num_workers = 2,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle = False,
    drop_last=True,
    num_workers = 2,
)

Files already downloaded and verified


In [27]:
encoder = torchvision.models.resnet18()

n_features = encoder.fc.in_features
simclr_model = SimCLR(encoder, projection_dim, n_features)
model_fp = os.path.join(
    model_path, "checkpoint_{}.tar".format(epochs)
)
simclr_model.load_state_dict(torch.load(model_fp, map_location=DEVICE.type))
simclr_model.to(DEVICE)

SimCLR(
  (encoder): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runn

In [28]:
## Logistic Regression
n_classes = 10 
model = LogisticRegression(simclr_model.n_features, n_classes)
model = model.to(DEVICE)

In [29]:
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = torch.nn.CrossEntropyLoss()

In [30]:
def inference(loader, simclr_model, device):
    feature_vector = []
    labels_vector = []
    for step, (x, y) in enumerate(loader):
        x = x.to(device)

        # get encoding
        with torch.no_grad():
            h, _, z, _, _, _, _, _ = simclr_model(x, x)

        h = h.detach()

        feature_vector.extend(h.cpu().detach().numpy())
        labels_vector.extend(y.numpy())

        if step % 20 == 0:
            print(f"Step [{step}/{len(loader)}]\t Computing features...")

    feature_vector = np.array(feature_vector)
    labels_vector = np.array(labels_vector)
    print("Features shape {}".format(feature_vector.shape))
    return feature_vector, labels_vector


def get_features(context_model, train_loader, test_loader, device):
    train_X, train_y = inference(train_loader, context_model, device)
    test_X, test_y = inference(test_loader, context_model, device)
    return train_X, train_y, test_X, test_y


def create_data_loaders_from_arrays(X_train, y_train, X_test, y_test, batch_size):
    train = torch.utils.data.TensorDataset(
        torch.from_numpy(X_train), torch.from_numpy(y_train)
    )
    train_loader = torch.utils.data.DataLoader(
        train, batch_size=batch_size, shuffle=False
    )

    test = torch.utils.data.TensorDataset(
        torch.from_numpy(X_test), torch.from_numpy(y_test)
    )
    test_loader = torch.utils.data.DataLoader(
        test, batch_size=batch_size, shuffle=False
    )
    return train_loader, test_loader

In [31]:
print("### Creating features from pre-trained context model ###")
(train_X, train_y, test_X, test_y) = get_features(
    simclr_model, train_loader, test_loader, DEVICE
)

arr_train_loader, arr_test_loader = create_data_loaders_from_arrays(
    train_X, train_y, test_X, test_y, batch_size
)

### Creating features from pre-trained context model ###
Step [0/390]	 Computing features...
Step [20/390]	 Computing features...
Step [40/390]	 Computing features...
Step [60/390]	 Computing features...
Step [80/390]	 Computing features...
Step [100/390]	 Computing features...
Step [120/390]	 Computing features...
Step [140/390]	 Computing features...
Step [160/390]	 Computing features...
Step [180/390]	 Computing features...
Step [200/390]	 Computing features...
Step [220/390]	 Computing features...
Step [240/390]	 Computing features...
Step [260/390]	 Computing features...
Step [280/390]	 Computing features...
Step [300/390]	 Computing features...
Step [320/390]	 Computing features...
Step [340/390]	 Computing features...
Step [360/390]	 Computing features...
Step [380/390]	 Computing features...
Features shape (49920, 512)
Step [0/390]	 Computing features...
Step [20/390]	 Computing features...
Step [40/390]	 Computing features...
Step [60/390]	 Computing features...
Step [80/390]	

In [32]:
for epoch in tqdm(range(logistic_epochs)):
    loss_epoch, accuracy_epoch = train(arr_train_loader, model, criterion, optimizer)
    
    if epoch % 10 == 0:
      print(f"Epoch [{epoch}/{logistic_epochs}]\t Loss: {loss_epoch / len(train_loader)}\t Accuracy: {accuracy_epoch / len(train_loader)}")


# final testing
loss_epoch, accuracy_epoch = test(
    arr_test_loader, model, criterion, optimizer
)
print(
    f"[FINAL]\t Loss: {loss_epoch / len(test_loader)}\t Accuracy: {accuracy_epoch / len(test_loader)}"
)

  0%|          | 1/800 [00:00<09:41,  1.37it/s]

Epoch [0/800]	 Loss: 1.6693740520721827	 Accuracy: 0.39503205128205127


  1%|▏         | 11/800 [00:08<09:46,  1.35it/s]

Epoch [10/800]	 Loss: 1.4587524163417327	 Accuracy: 0.47584134615384616


  3%|▎         | 21/800 [00:15<09:09,  1.42it/s]

Epoch [20/800]	 Loss: 1.4294870052582178	 Accuracy: 0.4858573717948718


  4%|▍         | 31/800 [00:22<08:56,  1.43it/s]

Epoch [30/800]	 Loss: 1.414025585162334	 Accuracy: 0.49200721153846155


  5%|▌         | 41/800 [00:29<08:54,  1.42it/s]

Epoch [40/800]	 Loss: 1.4041841626167297	 Accuracy: 0.495713141025641


  6%|▋         | 51/800 [00:36<09:11,  1.36it/s]

Epoch [50/800]	 Loss: 1.3973153144885333	 Accuracy: 0.4980168269230769


  8%|▊         | 61/800 [00:44<08:35,  1.43it/s]

Epoch [60/800]	 Loss: 1.3922454323524083	 Accuracy: 0.49973958333333335


  9%|▉         | 71/800 [00:51<08:34,  1.42it/s]

Epoch [70/800]	 Loss: 1.3883601051110488	 Accuracy: 0.5015424679487179


 10%|█         | 81/800 [00:58<08:20,  1.44it/s]

Epoch [80/800]	 Loss: 1.3853013329016857	 Accuracy: 0.5026642628205128


 11%|█▏        | 91/800 [01:05<08:15,  1.43it/s]

Epoch [90/800]	 Loss: 1.382843977977068	 Accuracy: 0.5034655448717948


 13%|█▎        | 101/800 [01:13<09:04,  1.28it/s]

Epoch [100/800]	 Loss: 1.3808383107185365	 Accuracy: 0.5044270833333333


 14%|█▍        | 111/800 [01:20<08:06,  1.42it/s]

Epoch [110/800]	 Loss: 1.3791802766995551	 Accuracy: 0.5048277243589744


 15%|█▌        | 121/800 [01:27<08:01,  1.41it/s]

Epoch [120/800]	 Loss: 1.3777951329182356	 Accuracy: 0.5052884615384615


 16%|█▋        | 131/800 [01:34<08:07,  1.37it/s]

Epoch [130/800]	 Loss: 1.3766276130309472	 Accuracy: 0.5056490384615384


 18%|█▊        | 141/800 [01:41<08:11,  1.34it/s]

Epoch [140/800]	 Loss: 1.375635998065655	 Accuracy: 0.5060496794871795


 19%|█▉        | 151/800 [01:49<07:43,  1.40it/s]

Epoch [150/800]	 Loss: 1.3747881867946723	 Accuracy: 0.5065104166666666


 20%|██        | 161/800 [01:56<07:33,  1.41it/s]

Epoch [160/800]	 Loss: 1.3740591009457905	 Accuracy: 0.5067107371794872


 21%|██▏       | 171/800 [02:03<07:18,  1.44it/s]

Epoch [170/800]	 Loss: 1.373428810865451	 Accuracy: 0.5068709935897436


 23%|██▎       | 181/800 [02:10<07:18,  1.41it/s]

Epoch [180/800]	 Loss: 1.3728814280950106	 Accuracy: 0.5074919871794872


 24%|██▍       | 191/800 [02:18<07:35,  1.34it/s]

Epoch [190/800]	 Loss: 1.3724040437967349	 Accuracy: 0.5076522435897436


 25%|██▌       | 201/800 [02:25<07:06,  1.41it/s]

Epoch [200/800]	 Loss: 1.3719861015295371	 Accuracy: 0.5077323717948717


 26%|██▋       | 211/800 [02:32<07:01,  1.40it/s]

Epoch [210/800]	 Loss: 1.371618914604187	 Accuracy: 0.5079727564102564


 28%|██▊       | 221/800 [02:39<07:01,  1.37it/s]

Epoch [220/800]	 Loss: 1.3712953087611077	 Accuracy: 0.5080729166666667


 29%|██▉       | 231/800 [02:46<06:47,  1.40it/s]

Epoch [230/800]	 Loss: 1.3710092281683897	 Accuracy: 0.5082131410256411


 30%|███       | 241/800 [02:53<06:31,  1.43it/s]

Epoch [240/800]	 Loss: 1.3707556446393332	 Accuracy: 0.5083934294871795


 31%|███▏      | 251/800 [03:01<06:54,  1.32it/s]

Epoch [250/800]	 Loss: 1.3705302956776741	 Accuracy: 0.5083133012820513


 33%|███▎      | 261/800 [03:08<06:16,  1.43it/s]

Epoch [260/800]	 Loss: 1.3703295533473676	 Accuracy: 0.5083733974358975


 34%|███▍      | 271/800 [03:15<06:59,  1.26it/s]

Epoch [270/800]	 Loss: 1.370150320040874	 Accuracy: 0.5083934294871795


 35%|███▌      | 281/800 [03:22<06:46,  1.28it/s]

Epoch [280/800]	 Loss: 1.3699899710141696	 Accuracy: 0.5087940705128206


 36%|███▋      | 291/800 [03:30<06:07,  1.38it/s]

Epoch [290/800]	 Loss: 1.3698462137809166	 Accuracy: 0.5088141025641025


 38%|███▊      | 301/800 [03:37<05:49,  1.43it/s]

Epoch [300/800]	 Loss: 1.3697170994220635	 Accuracy: 0.5089743589743589


 39%|███▉      | 311/800 [03:44<05:42,  1.43it/s]

Epoch [310/800]	 Loss: 1.3696009125465	 Accuracy: 0.5091346153846154


 40%|████      | 321/800 [03:51<05:27,  1.46it/s]

Epoch [320/800]	 Loss: 1.369496198800894	 Accuracy: 0.5092347756410256


 41%|████▏     | 331/800 [03:58<05:32,  1.41it/s]

Epoch [330/800]	 Loss: 1.3694016795891981	 Accuracy: 0.509395032051282


 43%|████▎     | 341/800 [04:05<05:20,  1.43it/s]

Epoch [340/800]	 Loss: 1.3693162202835083	 Accuracy: 0.5093149038461539


 44%|████▍     | 351/800 [04:12<05:10,  1.45it/s]

Epoch [350/800]	 Loss: 1.369238831446721	 Accuracy: 0.5091346153846154


 45%|████▌     | 361/800 [04:20<05:44,  1.28it/s]

Epoch [360/800]	 Loss: 1.3691686752514962	 Accuracy: 0.5091546474358974


 46%|████▋     | 371/800 [04:27<04:55,  1.45it/s]

Epoch [370/800]	 Loss: 1.3691049716411492	 Accuracy: 0.5092748397435898


 48%|████▊     | 381/800 [04:34<04:46,  1.46it/s]

Epoch [380/800]	 Loss: 1.3690470573229667	 Accuracy: 0.5090745192307692


 49%|████▉     | 391/800 [04:41<04:55,  1.38it/s]

Epoch [390/800]	 Loss: 1.3689943518394079	 Accuracy: 0.5091546474358974


 50%|█████     | 401/800 [04:48<04:42,  1.41it/s]

Epoch [400/800]	 Loss: 1.368946310495719	 Accuracy: 0.5091145833333334


 51%|█████▏    | 411/800 [04:55<04:32,  1.43it/s]

Epoch [410/800]	 Loss: 1.3689025013874738	 Accuracy: 0.5091746794871795


 53%|█████▎    | 421/800 [05:03<04:37,  1.36it/s]

Epoch [420/800]	 Loss: 1.3688624608211029	 Accuracy: 0.5091746794871795


 54%|█████▍    | 431/800 [05:10<04:12,  1.46it/s]

Epoch [430/800]	 Loss: 1.3688258696825077	 Accuracy: 0.5092147435897436


 55%|█████▌    | 441/800 [05:17<04:07,  1.45it/s]

Epoch [440/800]	 Loss: 1.3687923639248578	 Accuracy: 0.5092748397435898


 56%|█████▋    | 451/800 [05:24<04:19,  1.34it/s]

Epoch [450/800]	 Loss: 1.3687616852613596	 Accuracy: 0.5092748397435898


 58%|█████▊    | 461/800 [05:31<03:56,  1.43it/s]

Epoch [460/800]	 Loss: 1.3687335258875137	 Accuracy: 0.5093549679487179


 59%|█████▉    | 471/800 [05:38<03:46,  1.45it/s]

Epoch [470/800]	 Loss: 1.3687076978194408	 Accuracy: 0.5094150641025641


 60%|██████    | 481/800 [05:45<03:48,  1.39it/s]

Epoch [480/800]	 Loss: 1.3686839647782154	 Accuracy: 0.5095753205128205


 61%|██████▏   | 491/800 [05:52<03:58,  1.29it/s]

Epoch [490/800]	 Loss: 1.3686621384742932	 Accuracy: 0.5095552884615384


 63%|██████▎   | 501/800 [06:00<03:36,  1.38it/s]

Epoch [500/800]	 Loss: 1.3686420260331569	 Accuracy: 0.5094350961538462


 64%|██████▍   | 511/800 [06:07<03:23,  1.42it/s]

Epoch [510/800]	 Loss: 1.3686235155814732	 Accuracy: 0.5095352564102564


 65%|██████▌   | 521/800 [06:14<03:15,  1.43it/s]

Epoch [520/800]	 Loss: 1.3686064573434682	 Accuracy: 0.5094551282051282


 66%|██████▋   | 531/800 [06:21<03:13,  1.39it/s]

Epoch [530/800]	 Loss: 1.368590708879324	 Accuracy: 0.5094951923076924


 68%|██████▊   | 541/800 [06:28<03:07,  1.38it/s]

Epoch [540/800]	 Loss: 1.3685761570930481	 Accuracy: 0.509375


 69%|██████▉   | 551/800 [06:35<02:52,  1.44it/s]

Epoch [550/800]	 Loss: 1.3685627115078463	 Accuracy: 0.5093549679487179


 70%|███████   | 561/800 [06:43<02:50,  1.40it/s]

Epoch [560/800]	 Loss: 1.3685502712543194	 Accuracy: 0.5094551282051282


 71%|███████▏  | 571/800 [06:49<02:34,  1.48it/s]

Epoch [570/800]	 Loss: 1.3685387620559106	 Accuracy: 0.5094150641025641


 73%|███████▎  | 581/800 [06:57<02:41,  1.35it/s]

Epoch [580/800]	 Loss: 1.3685280888508529	 Accuracy: 0.509395032051282


 74%|███████▍  | 591/800 [07:04<02:28,  1.41it/s]

Epoch [590/800]	 Loss: 1.3685181984534631	 Accuracy: 0.5094350961538462


 75%|███████▌  | 601/800 [07:11<02:18,  1.43it/s]

Epoch [600/800]	 Loss: 1.3685090318704263	 Accuracy: 0.509334935897436


 76%|███████▋  | 611/800 [07:18<02:10,  1.45it/s]

Epoch [610/800]	 Loss: 1.3685005041269156	 Accuracy: 0.5092748397435898


 78%|███████▊  | 621/800 [07:25<02:11,  1.36it/s]

Epoch [620/800]	 Loss: 1.3684925730411823	 Accuracy: 0.5091746794871795


 79%|███████▉  | 631/800 [07:33<02:02,  1.38it/s]

Epoch [630/800]	 Loss: 1.3684852089637365	 Accuracy: 0.5092147435897436


 80%|████████  | 641/800 [07:40<01:52,  1.42it/s]

Epoch [640/800]	 Loss: 1.3684783608485491	 Accuracy: 0.5091546474358974


 81%|████████▏ | 651/800 [07:47<01:47,  1.39it/s]

Epoch [650/800]	 Loss: 1.3684719745929426	 Accuracy: 0.5091346153846154


 83%|████████▎ | 661/800 [07:54<01:35,  1.45it/s]

Epoch [660/800]	 Loss: 1.3684660122944758	 Accuracy: 0.5091746794871795


 84%|████████▍ | 671/800 [08:02<01:35,  1.35it/s]

Epoch [670/800]	 Loss: 1.3684604632548796	 Accuracy: 0.5091346153846154


 85%|████████▌ | 681/800 [08:09<01:29,  1.33it/s]

Epoch [680/800]	 Loss: 1.3684552794847733	 Accuracy: 0.5090945512820513


 86%|████████▋ | 691/800 [08:16<01:15,  1.44it/s]

Epoch [690/800]	 Loss: 1.3684504191080729	 Accuracy: 0.5090945512820513


 88%|████████▊ | 701/800 [08:23<01:10,  1.40it/s]

Epoch [700/800]	 Loss: 1.3684459062723013	 Accuracy: 0.5090544871794872


 89%|████████▉ | 711/800 [08:30<01:15,  1.17it/s]

Epoch [710/800]	 Loss: 1.3684416547799723	 Accuracy: 0.5090144230769231


 90%|█████████ | 721/800 [08:37<00:54,  1.44it/s]

Epoch [720/800]	 Loss: 1.3684376817483168	 Accuracy: 0.5090544871794872


 91%|█████████▏| 731/800 [08:45<00:48,  1.41it/s]

Epoch [730/800]	 Loss: 1.3684339593618344	 Accuracy: 0.5090945512820513


 93%|█████████▎| 741/800 [08:52<00:41,  1.41it/s]

Epoch [740/800]	 Loss: 1.3684304390198145	 Accuracy: 0.5090945512820513


 94%|█████████▍| 751/800 [08:59<00:33,  1.45it/s]

Epoch [750/800]	 Loss: 1.3684271641266652	 Accuracy: 0.5090745192307692


 95%|█████████▌| 761/800 [09:06<00:28,  1.36it/s]

Epoch [760/800]	 Loss: 1.3684240888326595	 Accuracy: 0.5090344551282051


 96%|█████████▋| 771/800 [09:13<00:20,  1.39it/s]

Epoch [770/800]	 Loss: 1.3684211862392914	 Accuracy: 0.5090144230769231


 98%|█████████▊| 781/800 [09:20<00:13,  1.43it/s]

Epoch [780/800]	 Loss: 1.3684184493162692	 Accuracy: 0.5090344551282051


 99%|█████████▉| 791/800 [09:27<00:06,  1.44it/s]

Epoch [790/800]	 Loss: 1.368415880508912	 Accuracy: 0.5090544871794872


100%|██████████| 800/800 [09:34<00:00,  1.39it/s]


[FINAL]	 Loss: 1.3613630022758092	 Accuracy: 0.5117387820512821
