In [1]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
print(gpu_info)

from datetime import datetime
from functools import partial
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet
from tqdm import tqdm
import argparse
import json
import math
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import lightly.models as models
import lightly.loss as loss

os.environ["CUDA_VISIBLE_DEVICES"]="0,1"

Mon Jun 28 15:48:51 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.80       Driver Version: 460.80       CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  TITAN V             Off  | 00000000:3B:00.0 Off |                  N/A |
| 28%   41C    P2    43W / 250W |   4515MiB / 12066MiB |     79%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  TITAN V             Off  | 00000000:5E:00.0 Off |                  N/A |
| 29%   42C    P2    43W / 250W |   4515MiB / 12066MiB |     16%      Default |
|       

## Create dataloader, model

In [6]:
parser = argparse.ArgumentParser(description='Train NNCLR on CelebA targetlist')

parser.add_argument('-a', '--arch', default='resnet18')

# lr: 0.06 for batch 512 (or 0.03 for batch 256)
parser.add_argument('--lr', '--learning-rate', default=0.06, type=float, metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--epochs', default=800, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('--schedule', default=[300, 600], nargs='*', type=int, help='learning rate schedule (when to drop lr by 10x); does not take effect if --cos is on')
parser.add_argument('--cos', action='store_true', help='use cosine lr schedule')

parser.add_argument('--batch-size', default=256, type=int, metavar='N', help='mini-batch size')
parser.add_argument('--wd', default=5e-4, type=float, metavar='W', help='weight decay')

parser.add_argument('--bn-splits', default=8, type=int, help='simulate multi-gpu behavior of BatchNorm in one gpu; 1 is SyncBatchNorm in multi-gpu')

parser.add_argument('--symmetric', action='store_true', help='use a symmetric loss function that backprops to both crops')

# knn monitor
parser.add_argument('--knn-k', default=1, type=int, help='k in kNN monitor')
parser.add_argument('--knn-t', default=0.1, type=float, help='softmax temperature in kNN monitor; could be different with moco-t')

# utils
parser.add_argument('--resume', default='', 
                    type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('--results-dir', default='', type=str, metavar='PATH', help='path to cache (default: none)')

'''
args = parser.parse_args()  # running in command line
'''
args = parser.parse_args('')  # running in ipynb

# set command line arguments here when running in ipynb
args.epochs = 800
args.cos = True
args.schedule = []  # cos in use
args.symmetric = False
if args.results_dir == '':
    args.results_dir = '../output/CelebA-NNCLR-' + datetime.now().strftime("%Y-%m-%d-%H-%M-%S-moco")

print(args)

Namespace(arch='resnet18', batch_size=256, bn_splits=8, cos=True, epochs=800, knn_k=1, knn_t=0.1, lr=0.06, results_dir='../output/CelebA-NNCLR-2021-06-28-15-49-22-moco', resume='', schedule=[], symmetric=False, wd=0.0005)


In [7]:
import pickle
from PIL import Image, ImageOps
from typing import List, Union, Callable


class CelebAPair(torchvision.datasets.CelebA):
    """CIFAR10 Dataset.
    """
    def __getitem__(self, index):

        img = Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))

        if self.transform is not None:
            im_1 = self.transform(img)
            im_2 = self.transform(img)

        return im_1, im_2

    
train_transform = transforms.Compose([
    transforms.CenterCrop(128),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

test_transform = transforms.Compose([
    transforms.CenterCrop(128),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

In [8]:
# data prepare
train_data = CelebAPair(root = '.', 
                        split = 'train', 
                        target_type = 'identity', 
                        transform = train_transform, 
                        target_transform = None, 
                        download=False)

train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, 
                          num_workers=16, pin_memory=True, drop_last=True)

In [9]:
memory_data = torchvision.datasets.CelebA(root = '.', 
                            split = 'train', 
                            target_type = 'identity', 
                            transform = test_transform, 
                            target_transform = None, 
                            download=False) 

memory_loader = DataLoader(memory_data, batch_size=args.batch_size, 
                           shuffle=False, num_workers=16, pin_memory=True)

In [10]:
test_data = torchvision.datasets.CelebA(root = '.', 
                            split = 'valid', 
                            target_type = 'identity', 
                            transform = test_transform, 
                            target_transform = None, 
                            download=False) 

test_loader = DataLoader(test_data, batch_size=args.batch_size, 
                         shuffle=False, num_workers=16, pin_memory=True)

In [11]:
# add labeldict because identity is not continuous
labeldict = {x:i for i, x in enumerate(set(memory_data.identity[:, 0].numpy()).union(set(test_data.identity[:, 0].numpy())))}

In [12]:
len(labeldict.values())

9177

## Model

- For other Siamese SOTA

In [13]:
# use a resnet50 backbone
# resnet = torchvision.models.resnet.resnet18()
# resnet = torch.nn.Sequential(*list(resnet.children())[:-1])

# # build the simsiam model
# model = models.SimCLR(resnet, num_ftrs=512)
# model = model.cuda()

# # use a criterion for self-supervised learning
# criterion = loss.NTXentLoss(temperature=0.5)

- For NNCLR

In [20]:
import torchvision
import torch.nn as nn
from lightly.models import NNCLR
from lightly.loss import NTXentLoss
from lightly.models.modules import NNMemoryBankModule

resnet = torchvision.models.resnet18()
backbone = nn.Sequential(
    *list(resnet.children())[:-1],
    nn.AdaptiveAvgPool2d(1),
)

# NNCLR
model = NNCLR(backbone)
model = torch.nn.DataParallel(model)
model = model.cuda()
criterion = NTXentLoss()

nn_replacer = NNMemoryBankModule(size=2 ** 16)

In [15]:
# train for one epoch

######################### For other models ##########################
# def train(net, data_loader, train_optimizer, epoch, args):
#     net.train()
#     adjust_learning_rate(optimizer, epoch, args)

#     total_loss, total_num, train_bar = 0.0, 0, tqdm(data_loader)
#     for im_1, im_2 in train_bar:
#         im_1, im_2 = im_1.cuda(non_blocking=True), im_2.cuda(non_blocking=True)

#         y0, y1 = net(im_1, im_2)
#         # backpropagation
#         loss = criterion(y0, y1)
        
#         train_optimizer.zero_grad()
#         loss.backward()
#         train_optimizer.step()

#         total_num += data_loader.batch_size
#         total_loss += loss.item() * data_loader.batch_size
#         train_bar.set_description('Train Epoch: [{}/{}], lr: {:.6f}, Loss: {:.4f}'.format(epoch, args.epochs, optimizer.param_groups[0]['lr'], total_loss / total_num))

#     return total_loss / total_num
##############################################################################

######################### For NNCLR ##########################
def train(net, data_loader, train_optimizer, epoch, args):
    net.train()
    adjust_learning_rate(optimizer, epoch, args)

    total_loss, total_num, train_bar = 0.0, 0, tqdm(data_loader)
    for im_1, im_2 in train_bar:
        im_1, im_2 = im_1.cuda(non_blocking=True), im_2.cuda(non_blocking=True)

        # forward pass
        (z0, p0), (z1, p1) = model(im_1, im_2)
        z0 = nn_replacer(z0.detach(), update=False)
        z1 = nn_replacer(z1.detach(), update=True)
        loss = 0.5 * (criterion(z0, p1) + criterion(z1, p0))
        
        train_optimizer.zero_grad()
        loss.backward()
        train_optimizer.step()

        total_num += data_loader.batch_size
        total_loss += loss.item() * data_loader.batch_size
        train_bar.set_description('Train Epoch: [{}/{}], lr: {:.6f}, Loss: {:.4f}'.format(epoch, args.epochs, optimizer.param_groups[0]['lr'], total_loss / total_num))

    return total_loss / total_num
##############################################################################

# lr scheduler for training
def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate based on schedule"""
    lr = args.lr
    if args.cos:  # cosine lr schedule
        lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
    else:  # stepwise lr schedule
        for milestone in args.schedule:
            lr *= 0.1 if epoch >= milestone else 1.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [16]:
# test using a knn monitor
def test(net, memory_data_loader, test_data_loader, epoch, args):
    net.eval()
    classes = len(labeldict.keys())
    print('Number of classes {}'.format(classes))
    
    total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, []
    with torch.no_grad():
        # generate feature bank
        for data, target in tqdm(memory_data_loader, desc='Feature extracting'):
            feature = net(data.cuda(non_blocking=True)).squeeze().detach().cpu()
            feature = F.normalize(feature, dim=1)
            feature_bank.append(feature)
            
        # [D, N]
        feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
        # [N]
        feature_labels = torch.tensor([labeldict[x] for x in memory_data_loader.dataset.identity[:, 0].numpy()], 
                                      device=feature_bank.device)
#         print(feature_labels.shape)
        
        # loop test data to predict the label by weighted knn search
        test_bar = tqdm(test_data_loader)
        for data, target in test_bar:
            target = torch.tensor([labeldict[x] for x in target.numpy()], device=feature_labels.device)
            data = data.cuda(non_blocking=True)
            feature = net(data).squeeze().detach().cpu()
            feature = F.normalize(feature, dim=1)
            
            pred_labels = knn_predict(feature, feature_bank, feature_labels, classes, args.knn_k, args.knn_t)

            total_num += data.size(0)
            total_top1 += (pred_labels[:, 0] == target).float().sum().item()
            test_bar.set_description('Test Epoch: [{}/{}] Acc@1:{:.2f}%'.format(epoch, args.epochs, total_top1 / total_num * 100))

    return total_top1 / total_num * 100

# knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
# implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR
def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
    # compute cos similarity between each feature vector and feature bank ---> [B, N]
    sim_matrix = torch.mm(feature, feature_bank)
    # [B, K]
    sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
    # [B, K]
    sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), dim=-1, index=sim_indices)
    sim_weight = (sim_weight / knn_t).exp()

    # counts for each class
    one_hot_label = torch.zeros(feature.size(0) * knn_k, classes, device=sim_labels.device)
    # [B*K, C]
    one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
    # weighted score ---> [B, C]
    pred_scores = torch.sum(one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1)

    pred_labels = pred_scores.argsort(dim=-1, descending=True)
    return pred_labels

In [17]:
# define optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wd, momentum=0.9)

# load model if resume
epoch_start = 1
if args.resume is not '':
    checkpoint = torch.load(args.resume)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    epoch_start = checkpoint['epoch'] + 1
    print('Loaded from: {}'.format(args.resume))

In [18]:
# logging
results = {'train_loss': [], 'test_acc@1': []}
if not os.path.exists(args.results_dir):
    os.mkdir(args.results_dir)
    
# dump args
with open(args.results_dir + '/args.json', 'w') as fid:
    json.dump(args.__dict__, fid, indent=2)


In [None]:
# training loop
for epoch in range(epoch_start, args.epochs + 1):
    test_acc_1 = test(model.module.backbone, memory_loader, test_loader, epoch, args)
    train_loss = train(model, train_loader, optimizer, epoch, args)
    results['train_loss'].append(train_loss)
    test_acc_1 = test(model.module.backbone, memory_loader, test_loader, epoch, args)
    results['test_acc@1'].append(test_acc_1)
    # save statistics
    data_frame = pd.DataFrame(data=results, index=range(epoch_start, epoch + 1))
    data_frame.to_csv(args.results_dir + '/log.csv', index_label='epoch')
    # save model
    torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 
                'optimizer' : optimizer.state_dict(),}, args.results_dir + '/model_last.pth')

Feature extracting:   0%|          | 0/636 [00:00<?, ?it/s]

Number of classes 9177


Feature extracting: 100%|██████████| 636/636 [01:06<00:00,  9.63it/s]
Test Epoch: [1/800] Acc@1:0.00%: 100%|██████████| 78/78 [01:57<00:00,  1.50s/it]
Train Epoch: [1/800], lr: 0.060000, Loss: 6.4864: 100%|██████████| 635/635 [04:35<00:00,  2.31it/s]
Feature extracting:   0%|          | 0/636 [00:00<?, ?it/s]

Number of classes 9177


Feature extracting: 100%|██████████| 636/636 [01:07<00:00,  9.48it/s]
Test Epoch: [1/800] Acc@1:0.00%: 100%|██████████| 78/78 [02:01<00:00,  1.55s/it]
Feature extracting:   0%|          | 0/636 [00:00<?, ?it/s]

Number of classes 9177


Feature extracting: 100%|██████████| 636/636 [01:17<00:00,  8.17it/s]
Test Epoch: [2/800] Acc@1:0.00%: 100%|██████████| 78/78 [02:13<00:00,  1.71s/it]
Train Epoch: [2/800], lr: 0.059999, Loss: 6.4857: 100%|██████████| 635/635 [04:15<00:00,  2.49it/s]
Feature extracting:   0%|          | 0/636 [00:00<?, ?it/s]

Number of classes 9177


Feature extracting: 100%|██████████| 636/636 [01:10<00:00,  9.05it/s]
Test Epoch: [2/800] Acc@1:0.00%: 100%|██████████| 78/78 [02:09<00:00,  1.65s/it]
Feature extracting:   0%|          | 0/636 [00:00<?, ?it/s]

Number of classes 9177


Feature extracting: 100%|██████████| 636/636 [01:17<00:00,  8.25it/s]
Test Epoch: [3/800] Acc@1:0.00%: 100%|██████████| 78/78 [02:14<00:00,  1.73s/it]
Train Epoch: [3/800], lr: 0.059998, Loss: 6.4862: 100%|██████████| 635/635 [04:26<00:00,  2.38it/s]
Feature extracting:   0%|          | 0/636 [00:00<?, ?it/s]

Number of classes 9177


Feature extracting: 100%|██████████| 636/636 [01:07<00:00,  9.46it/s]
Test Epoch: [3/800] Acc@1:0.00%: 100%|██████████| 78/78 [02:06<00:00,  1.62s/it]
Feature extracting:   0%|          | 0/636 [00:00<?, ?it/s]

Number of classes 9177


Feature extracting: 100%|██████████| 636/636 [01:05<00:00,  9.76it/s]
Test Epoch: [4/800] Acc@1:0.00%: 100%|██████████| 78/78 [02:10<00:00,  1.67s/it]
Train Epoch: [4/800], lr: 0.059996, Loss: 6.4858: 100%|██████████| 635/635 [04:31<00:00,  2.34it/s]
Feature extracting:   0%|          | 0/636 [00:00<?, ?it/s]

Number of classes 9177


Feature extracting: 100%|██████████| 636/636 [01:09<00:00,  9.16it/s]
Test Epoch: [4/800] Acc@1:0.00%: 100%|██████████| 78/78 [02:12<00:00,  1.70s/it]
Feature extracting:   0%|          | 0/636 [00:00<?, ?it/s]

Number of classes 9177


Feature extracting: 100%|██████████| 636/636 [01:12<00:00,  8.82it/s]
Test Epoch: [5/800] Acc@1:0.00%: 100%|██████████| 78/78 [02:16<00:00,  1.76s/it]
Train Epoch: [5/800], lr: 0.059994, Loss: 6.4859: 100%|██████████| 635/635 [04:08<00:00,  2.55it/s]
Feature extracting:   0%|          | 0/636 [00:00<?, ?it/s]

Number of classes 9177


Feature extracting: 100%|██████████| 636/636 [01:18<00:00,  8.09it/s]
Test Epoch: [5/800] Acc@1:0.00%: 100%|██████████| 78/78 [02:01<00:00,  1.56s/it]
Feature extracting:   0%|          | 0/636 [00:00<?, ?it/s]

Number of classes 9177


Feature extracting: 100%|██████████| 636/636 [01:05<00:00,  9.72it/s]
Test Epoch: [6/800] Acc@1:0.00%: 100%|██████████| 78/78 [02:11<00:00,  1.69s/it]
Train Epoch: [6/800], lr: 0.059992, Loss: 6.4859: 100%|██████████| 635/635 [04:11<00:00,  2.53it/s]
Feature extracting:   0%|          | 0/636 [00:00<?, ?it/s]

Number of classes 9177


Feature extracting: 100%|██████████| 636/636 [01:08<00:00,  9.22it/s]
Test Epoch: [6/800] Acc@1:0.00%: 100%|██████████| 78/78 [02:15<00:00,  1.73s/it]
Feature extracting:   0%|          | 0/636 [00:00<?, ?it/s]

Number of classes 9177


Feature extracting: 100%|██████████| 636/636 [01:11<00:00,  8.85it/s]
Test Epoch: [7/800] Acc@1:0.00%: 100%|██████████| 78/78 [02:00<00:00,  1.54s/it]
Train Epoch: [7/800], lr: 0.059989, Loss: 6.4858: 100%|██████████| 635/635 [04:24<00:00,  2.40it/s]
Feature extracting:   0%|          | 0/636 [00:00<?, ?it/s]

Number of classes 9177


Feature extracting: 100%|██████████| 636/636 [01:15<00:00,  8.41it/s]
Test Epoch: [7/800] Acc@1:0.00%: 100%|██████████| 78/78 [02:09<00:00,  1.67s/it]
Feature extracting:   0%|          | 0/636 [00:00<?, ?it/s]

Number of classes 9177


Feature extracting: 100%|██████████| 636/636 [01:05<00:00,  9.75it/s]
Test Epoch: [8/800] Acc@1:0.00%: 100%|██████████| 78/78 [02:10<00:00,  1.67s/it]
Train Epoch: [8/800], lr: 0.059985, Loss: 6.4854: 100%|██████████| 635/635 [04:10<00:00,  2.53it/s]
Feature extracting:   0%|          | 0/636 [00:00<?, ?it/s]

Number of classes 9177


Feature extracting: 100%|██████████| 636/636 [01:06<00:00,  9.52it/s]
Test Epoch: [8/800] Acc@1:0.00%: 100%|██████████| 78/78 [02:07<00:00,  1.64s/it]
Feature extracting:   0%|          | 0/636 [00:00<?, ?it/s]

Number of classes 9177


Feature extracting: 100%|██████████| 636/636 [01:07<00:00,  9.48it/s]
Test Epoch: [9/800] Acc@1:0.00%: 100%|██████████| 78/78 [02:13<00:00,  1.71s/it]
Train Epoch: [9/800], lr: 0.059981, Loss: 6.4856: 100%|██████████| 635/635 [04:08<00:00,  2.56it/s]
Feature extracting:   0%|          | 0/636 [00:00<?, ?it/s]

Number of classes 9177


Feature extracting: 100%|██████████| 636/636 [01:10<00:00,  9.03it/s]
Test Epoch: [9/800] Acc@1:0.00%: 100%|██████████| 78/78 [02:20<00:00,  1.80s/it]
Feature extracting:   0%|          | 0/636 [00:00<?, ?it/s]

Number of classes 9177


Feature extracting: 100%|██████████| 636/636 [01:10<00:00,  9.06it/s]
Test Epoch: [10/800] Acc@1:0.00%: 100%|██████████| 78/78 [02:06<00:00,  1.63s/it]
Train Epoch: [10/800], lr: 0.059977, Loss: 6.4860: 100%|██████████| 635/635 [04:04<00:00,  2.60it/s]
Feature extracting:   0%|          | 0/636 [00:00<?, ?it/s]

Number of classes 9177


Feature extracting: 100%|██████████| 636/636 [01:11<00:00,  8.91it/s]
Test Epoch: [10/800] Acc@1:0.00%: 100%|██████████| 78/78 [02:12<00:00,  1.69s/it]
Feature extracting:   0%|          | 0/636 [00:00<?, ?it/s]

Number of classes 9177


Feature extracting: 100%|██████████| 636/636 [01:06<00:00,  9.57it/s]
Test Epoch: [11/800] Acc@1:0.00%: 100%|██████████| 78/78 [02:22<00:00,  1.83s/it]
Train Epoch: [11/800], lr: 0.059972, Loss: 6.4858: 100%|██████████| 635/635 [04:05<00:00,  2.59it/s]
Feature extracting:   0%|          | 0/636 [00:00<?, ?it/s]

Number of classes 9177


Feature extracting: 100%|██████████| 636/636 [01:15<00:00,  8.47it/s]
Test Epoch: [11/800] Acc@1:0.00%: 100%|██████████| 78/78 [02:09<00:00,  1.66s/it]
Feature extracting:   0%|          | 0/636 [00:00<?, ?it/s]

Number of classes 9177


Feature extracting: 100%|██████████| 636/636 [01:06<00:00,  9.62it/s]
Test Epoch: [12/800] Acc@1:0.00%: 100%|██████████| 78/78 [02:06<00:00,  1.63s/it]
Train Epoch: [12/800], lr: 0.059967, Loss: 6.4860: 100%|██████████| 635/635 [04:01<00:00,  2.63it/s]
Feature extracting:   0%|          | 0/636 [00:00<?, ?it/s]

Number of classes 9177


Feature extracting: 100%|██████████| 636/636 [01:14<00:00,  8.49it/s]
Test Epoch: [12/800] Acc@1:0.00%: 100%|██████████| 78/78 [02:16<00:00,  1.76s/it]
Feature extracting:   0%|          | 0/636 [00:00<?, ?it/s]

Number of classes 9177


Feature extracting: 100%|██████████| 636/636 [01:15<00:00,  8.42it/s]
Test Epoch: [13/800] Acc@1:0.00%: 100%|██████████| 78/78 [02:05<00:00,  1.61s/it]
Train Epoch: [13/800], lr: 0.059961, Loss: 6.4860: 100%|██████████| 635/635 [04:10<00:00,  2.53it/s]
Feature extracting:   0%|          | 0/636 [00:00<?, ?it/s]

Number of classes 9177


Feature extracting: 100%|██████████| 636/636 [01:05<00:00,  9.66it/s]
Test Epoch: [13/800] Acc@1:0.00%: 100%|██████████| 78/78 [02:13<00:00,  1.71s/it]
Feature extracting:   0%|          | 0/636 [00:00<?, ?it/s]

Number of classes 9177


Feature extracting: 100%|██████████| 636/636 [01:11<00:00,  8.93it/s]
Test Epoch: [14/800] Acc@1:0.00%: 100%|██████████| 78/78 [02:17<00:00,  1.76s/it]
Train Epoch: [14/800], lr: 0.059955, Loss: 6.4859:  70%|██████▉   | 444/635 [03:11<01:08,  2.77it/s]