### Квантуем DeepLabV3 MobileNetV3

Стартуем с трейнлупа, который нам выдал PyTorch.

In [1]:
import datetime
import os
import pickle
import time
from copy import deepcopy
from pathlib import Path

import torch
import torch.utils.data
from torch import nn
from torch.ao.quantization.quantize_fx import convert_fx
from torch.ao.quantization.quantize_fx import fuse_fx
from torch.optim.lr_scheduler import PolynomialLR
from torchvision.models.segmentation import DeepLabV3_MobileNet_V3_Large_Weights, deeplabv3_mobilenet_v3_large, deeplabv3_resnet50, DeepLabV3_ResNet50_Weights
from tqdm import tqdm

import utils
from quantization_utils.fake_quantization import fake_quantization
from quantization_utils.static_quantization import quantize_static
from train import evaluate
from train import get_dataset
from train import train_one_epoch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# The default training arguments were dumped to a pickle so that the notebook
# does not have to drive argparse.
# NOTE(review): pickle.load can execute arbitrary code while loading — only use
# a pickle file that you produced yourself.
with Path('./torch_default_args.pickle').open('rb') as file:
    args = pickle.load(file)

In [3]:
# Tune these values for your hardware.
# The dataset location can be overridden with the COCO_DATA_PATH environment
# variable so the notebook is not tied to one user's home directory; the
# original path is kept as the default, so existing runs are unaffected.
args.data_path = os.environ.get('COCO_DATA_PATH', '/home/d.chudakov/datasets/coco/')
args.epochs = 1
args.batch_size = 32
args.workers = 8

In [4]:
# Display the effective configuration (bare last expression -> rich repr).
args

Namespace(amp=False, aux_loss=False, backend='pil', batch_size=32, data_path='/home/d.chudakov/datasets/coco/', dataset='coco', device='cuda', dist_url='env://', epochs=1, lr=0.01, lr_warmup_decay=0.01, lr_warmup_epochs=0, lr_warmup_method='linear', model='deeplabv3_mobilenet_v3_large', momentum=0.9, output_dir='.', print_freq=10, resume='', start_epoch=0, test_only=False, use_deterministic_algorithms=False, use_v2=False, weight_decay=0.0001, weights=None, weights_backbone=None, workers=8, world_size=1)

### Сначала просто валидация обычной сетки, прям на гпу

In [5]:
# Float baseline: DeepLabV3 with a MobileNetV3-Large backbone and the default
# torchvision weights, switched to eval mode. The trailing ';' suppresses the
# module repr in the cell output.
model = deeplabv3_mobilenet_v3_large(weights=DeepLabV3_MobileNet_V3_Large_Weights.DEFAULT)
model.eval();

In [6]:
# Create the output directory (if one is configured) before anything is written.
if args.output_dir:
    utils.mkdir(args.output_dir)

# Helper from the reference training utilities; the captured output of this
# cell reports "Not using distributed mode".
utils.init_distributed_mode(args)

device = torch.device(args.device)

# Validation split; get_dataset also returns the number of classes.
dataset_test, num_classes = get_dataset(args, is_train=False)

# Sequential order keeps evaluation deterministic.
test_sampler = torch.utils.data.SequentialSampler(dataset_test)

# NOTE(review): batch_size=8 is hard-coded here; the training cells further
# down rebuild data_loader_test with batch_size=1.
data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=8, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
)

Not using distributed mode
loading annotations into memory...
Done (t=0.28s)
creating index...
index created!


In [7]:
# Evaluate the float model as the accuracy baseline.
# cuDNN autotuning is disabled and deterministic kernels are requested.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# Move the model to the same device that evaluate() is told to use. A bare
# .cuda() always targets the current CUDA device and would disagree with
# `device` whenever args.device is not plain 'cuda'.
model.to(device)
confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
print(confmat)

  return F.conv2d(input, weight, bias, self.stride,


Test:  [  0/625]  eta: 0:18:48    time: 1.8059  data: 0.7778  max mem: 1021
Test:  [100/625]  eta: 0:01:05    time: 0.1060  data: 0.0050  max mem: 2480
Test:  [200/625]  eta: 0:00:49    time: 0.1123  data: 0.0059  max mem: 2480
Test:  [300/625]  eta: 0:00:36    time: 0.1054  data: 0.0057  max mem: 2665
Test:  [400/625]  eta: 0:00:25    time: 0.1093  data: 0.0059  max mem: 2665
Test:  [500/625]  eta: 0:00:13    time: 0.1108  data: 0.0062  max mem: 2749
Test:  [600/625]  eta: 0:00:02    time: 0.1061  data: 0.0063  max mem: 3378
Test: Total time: 0:01:08
global correct: 91.4
average row correct: ['94.6', '84.3', '71.1', '72.8', '60.2', '49.3', '74.4', '61.5', '92.1', '35.9', '79.4', '58.7', '81.4', '80.3', '81.5', '88.0', '54.2', '87.6', '56.9', '84.7', '62.6']
IoU: ['90.4', '68.8', '56.4', '58.4', '45.8', '36.6', '67.6', '49.9', '76.7', '29.7', '64.2', '34.4', '62.7', '66.9', '68.8', '77.4', '29.4', '68.7', '46.3', '68.8', '52.3']
mean IoU: 58.1


  return torch.tensor(val)


### Заквантуем сетку статически, посмотрим на точность и скорость

In [8]:
# Quantize a deep copy of the model, so `model` itself is not handed to
# quantize_static.
# NOTE(review): data_loader_test and num_batches=1 are presumably the
# calibration data and the number of calibration batches — confirm in
# quantization_utils.static_quantization.
q_model = quantize_static(deepcopy(model), data_loader_test, num_batches=1)



In [9]:
# Time the quantized model on CPU.
# One batch is taken from the test loader and its first element is fed to the
# model twice; the tqdm bar reports the seconds per iteration.
sample = next(iter(data_loader_test))
q_model.cpu()
with torch.no_grad():
    for _ in tqdm(range(2)):
        q_model(sample[0])

100%|█████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00,  1.63s/it]


In [10]:
# Time the original (float) model on CPU.
sample = next(iter(data_loader_test))
model.cpu()
# NOTE(review): this rebinds `model` to the result of fuse_fx, so every later
# cell that uses `model` sees the fused module, and re-running this cell fuses
# an already fused module.
model = fuse_fx(model)
with torch.no_grad():
    for _ in tqdm(range(2)):
        model(sample[0])

100%|█████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00,  1.94s/it]


In [10]:
# Compute the metrics of the statically quantized model; it is evaluated on CPU.
q_model.cpu()
confmat = evaluate(q_model, data_loader_test, device='cpu', num_classes=num_classes)
print(confmat)

Test:  [   0/5000]  eta: 0:45:39    time: 0.5479  data: 0.3260  max mem: 222
Test:  [ 100/5000]  eta: 0:13:15    time: 0.1529  data: 0.0011  max mem: 222
Test:  [ 200/5000]  eta: 0:12:52    time: 0.1610  data: 0.0012  max mem: 222
Test:  [ 300/5000]  eta: 0:12:34    time: 0.1648  data: 0.0010  max mem: 222
Test:  [ 400/5000]  eta: 0:12:21    time: 0.1670  data: 0.0009  max mem: 222
Test:  [ 500/5000]  eta: 0:12:07    time: 0.1594  data: 0.0008  max mem: 222
Test:  [ 600/5000]  eta: 0:11:51    time: 0.1653  data: 0.0010  max mem: 222
Test:  [ 700/5000]  eta: 0:11:36    time: 0.1636  data: 0.0010  max mem: 222
Test:  [ 800/5000]  eta: 0:11:22    time: 0.1674  data: 0.0009  max mem: 222
Test:  [ 900/5000]  eta: 0:11:07    time: 0.1657  data: 0.0013  max mem: 222
Test:  [1000/5000]  eta: 0:10:51    time: 0.1724  data: 0.0009  max mem: 222
Test:  [1100/5000]  eta: 0:10:36    time: 0.1676  data: 0.0010  max mem: 222
Test:  [1200/5000]  eta: 0:10:20    time: 0.1670  data: 0.0011  max mem: 222

  return torch.tensor(val)


### Делаем Quantization Aware Training. Используем готовый трейнлуп от PyTorch

In [7]:
# Insert fake quantization for QAT.
# A deep copy is passed, mirroring the quantize_static cell, so that `model`
# stays usable as the float reference even if fake_quantization modifies its
# argument in place (its implementation is not visible from this notebook).
qat_model = fake_quantization(deepcopy(model), data_loader_test)



### Тут берём из train.py скрипт main() и вытаскиваем трейн луп

1. Не забыть провалидировать модель fake quant до qat
2. Не забыть провалидировать модель после обучения
3. Конвертировать модель из fake quant в обычный quant
4. Проверить точность и скорость модели

In [7]:
def criterion(inputs, target):
    """Segmentation loss over the model's output heads.

    Args:
        inputs: mapping from head name to logits; must contain ``"out"`` and,
            when more than one head is present, ``"aux"``.
        target: class-index tensor; entries equal to 255 are ignored.

    Returns:
        The cross-entropy of the ``"out"`` head when it is the only head,
        otherwise ``out + 0.5 * aux``.
    """
    per_head = {
        head: nn.functional.cross_entropy(logits, target, ignore_index=255)
        for head, logits in inputs.items()
    }

    main_loss = per_head["out"]
    if len(per_head) == 1:
        return main_loss

    # The auxiliary head contributes with half the weight of the main head.
    return main_loss + 0.5 * per_head["aux"]

In [9]:
# Move the fake-quantized model to the GPU (the ';' suppresses the module repr).
qat_model.cuda();

In [17]:
# Convert the fake-quantized model to a real quantized model on CPU.
# A deep copy is converted, as in the later conversion cell, because qat_model
# is moved back to the GPU, evaluated and trained further down and must stay
# in its fake-quant form.
qat_model.cpu()
int_qat_model = convert_fx(deepcopy(qat_model))



In [18]:
# Time the converted quantized model on CPU: two forward passes on the first
# element of one test batch, with tqdm reporting the seconds per iteration.
sample = next(iter(data_loader_test))
with torch.no_grad():
    for _ in tqdm(range(2)):
        int_qat_model(sample[0])

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:16<00:00,  8.41s/it]


In [14]:
# Move the fake-quantized model (back) to the GPU before evaluating it there.
qat_model.cuda();

In [15]:
# Validate the fake-quantized model before any QAT fine-tuning (step 1 of the
# checklist above), with cuDNN autotuning off and deterministic kernels on.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
confmat = evaluate(qat_model, data_loader_test, device=device, num_classes=num_classes)
print(confmat)

  return torch.fused_moving_avg_obs_fake_quant(
  return torch.fused_moving_avg_obs_fake_quant(


Test:  [  0/625]  eta: 0:13:42    time: 1.3167  data: 0.9676  max mem: 3378
Test:  [100/625]  eta: 0:01:17    time: 0.1317  data: 0.0029  max mem: 3378
Test:  [200/625]  eta: 0:00:59    time: 0.1341  data: 0.0030  max mem: 3378
Test:  [300/625]  eta: 0:00:44    time: 0.1295  data: 0.0030  max mem: 3626
Test:  [400/625]  eta: 0:00:30    time: 0.1331  data: 0.0032  max mem: 3626
Test:  [500/625]  eta: 0:00:17    time: 0.1393  data: 0.0032  max mem: 3732
Test:  [600/625]  eta: 0:00:03    time: 0.1300  data: 0.0031  max mem: 4657
Test: Total time: 0:01:24
global correct: 88.7
average row correct: ['95.9', '58.5', '60.2', '51.7', '35.1', '19.9', '62.6', '41.4', '65.7', '13.7', '42.4', '42.5', '44.9', '46.7', '59.2', '71.8', '15.7', '50.1', '21.0', '70.9', '36.4']
IoU: ['87.9', '52.6', '49.2', '27.8', '30.1', '18.0', '57.6', '34.6', '55.6', '12.6', '36.0', '27.8', '35.1', '41.7', '53.4', '63.4', '12.0', '43.9', '19.4', '62.5', '24.3']
mean IoU: 40.3


  return torch.tensor(val)


In [11]:
# Fine-tune with 1% of the reference learning rate.
# The unscaled value is remembered on `args` so that re-running this cell (or
# the identical cell of another experiment) does not compound the factor.
if not hasattr(args, "base_lr"):
    args.base_lr = args.lr
args.lr = args.base_lr * 0.01

In [12]:
# QAT training loop, lifted from main() of the reference train.py.
if args.output_dir:
    utils.mkdir(args.output_dir)

device = torch.device(args.device)

# Put the model on the training device here instead of relying on an earlier
# cell having done so; train_one_epoch/evaluate are only given `device`.
qat_model.to(device)

dataset, num_classes = get_dataset(args, is_train=True)
dataset_test, _ = get_dataset(args, is_train=False)

train_sampler = torch.utils.data.RandomSampler(dataset)
test_sampler = torch.utils.data.SequentialSampler(dataset_test)

data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=args.batch_size,
    sampler=train_sampler,
    num_workers=args.workers,
    collate_fn=utils.collate_fn,
    drop_last=True,
)

# NOTE: this rebinds data_loader_test with batch_size=1; cells executed after
# this one evaluate sample by sample.
data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
)

# One parameter group per sub-network; the aux classifier is added only when
# the auxiliary loss is enabled.
params_to_optimize = [
    {"params": [p for p in qat_model.backbone.parameters() if p.requires_grad]},
    {"params": [p for p in qat_model.classifier.parameters() if p.requires_grad]},
]
if args.aux_loss:
    params = [p for p in qat_model.aux_classifier.parameters() if p.requires_grad]
    params_to_optimize.append({"params": params, "lr": args.lr})

optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

# Gradient scaling is used only when mixed precision is requested.
scaler = torch.cuda.amp.GradScaler() if args.amp else None

# The scheduler is sized in iterations (epochs * batches per epoch), not epochs.
iters_per_epoch = len(data_loader)
main_lr_scheduler = PolynomialLR(
    optimizer, total_iters=iters_per_epoch * (args.epochs - args.lr_warmup_epochs), power=0.9
)

if args.lr_warmup_epochs > 0:
    warmup_iters = iters_per_epoch * args.lr_warmup_epochs
    args.lr_warmup_method = args.lr_warmup_method.lower()
    if args.lr_warmup_method == "linear":
        warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=args.lr_warmup_decay, total_iters=warmup_iters
        )
    elif args.lr_warmup_method == "constant":
        warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
            optimizer, factor=args.lr_warmup_decay, total_iters=warmup_iters
        )
    else:
        raise RuntimeError(
            f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
        )
    # Warm-up first, then the polynomial decay.
    lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
    )
else:
    lr_scheduler = main_lr_scheduler

start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
    train_one_epoch(qat_model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq, scaler)
    # Validate the fake-quantized model after every epoch.
    confmat = evaluate(qat_model, data_loader_test, device=device, num_classes=num_classes)
    print(confmat)
    checkpoint = {
        "model": qat_model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
        "epoch": epoch,
        "args": args,
    }
    if args.amp:
        checkpoint["scaler"] = scaler.state_dict()
    # Keep a per-epoch snapshot plus a rolling "latest" checkpoint.
    utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
    utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f"Training time {total_time_str}")

loading annotations into memory...
Done (t=8.80s)
creating index...
index created!
loading annotations into memory...
Done (t=0.23s)
creating index...
index created!


  return torch.fused_moving_avg_obs_fake_quant(
  return torch.fused_moving_avg_obs_fake_quant(
  return F.conv2d(input, weight, bias, self.stride,


Epoch: [0]  [   0/2891]  eta: 3:20:04  lr: 9.996886836501132e-05  loss: 0.9736 (0.9736)  time: 4.1523  data: 2.1492  max mem: 18302
Epoch: [0]  [  10/2891]  eta: 0:37:58  lr: 9.96574926992653e-05  loss: 0.9876 (0.9725)  time: 0.7909  data: 0.2028  max mem: 18331
Epoch: [0]  [  20/2891]  eta: 0:29:45  lr: 9.934600889796806e-05  loss: 0.9120 (0.9233)  time: 0.4453  data: 0.0055  max mem: 18331
Epoch: [0]  [  30/2891]  eta: 0:26:46  lr: 9.903441654658893e-05  loss: 0.8040 (0.8882)  time: 0.4354  data: 0.0032  max mem: 18331
Epoch: [0]  [  40/2891]  eta: 0:25:11  lr: 9.87227152275529e-05  loss: 0.8103 (0.8723)  time: 0.4342  data: 0.0033  max mem: 18331
Epoch: [0]  [  50/2891]  eta: 0:24:14  lr: 9.841090452020753e-05  loss: 0.8446 (0.8769)  time: 0.4351  data: 0.0034  max mem: 18331
Epoch: [0]  [  60/2891]  eta: 0:23:35  lr: 9.809898400078932e-05  loss: 0.8446 (0.8681)  time: 0.4381  data: 0.0037  max mem: 18331
Epoch: [0]  [  70/2891]  eta: 0:23:06  lr: 9.778695324238973e-05  loss: 0.7695

NameError: name 'confmat' is not defined

In [13]:
# Validate the fake-quantized model after training (step 2 of the checklist).
confmat = evaluate(qat_model, data_loader_test, device=device, num_classes=num_classes)
print(confmat)

Test:  [   0/5000]  eta: 0:43:49    time: 0.5259  data: 0.3771  max mem: 18331
Test:  [ 100/5000]  eta: 0:02:52    time: 0.0282  data: 0.0010  max mem: 18331
Test:  [ 200/5000]  eta: 0:02:24    time: 0.0245  data: 0.0009  max mem: 18331
Test:  [ 300/5000]  eta: 0:02:15    time: 0.0253  data: 0.0010  max mem: 18331
Test:  [ 400/5000]  eta: 0:02:10    time: 0.0303  data: 0.0011  max mem: 18331
Test:  [ 500/5000]  eta: 0:02:05    time: 0.0264  data: 0.0010  max mem: 18331
Test:  [ 600/5000]  eta: 0:02:01    time: 0.0260  data: 0.0010  max mem: 18331
Test:  [ 700/5000]  eta: 0:01:57    time: 0.0245  data: 0.0010  max mem: 18331
Test:  [ 800/5000]  eta: 0:01:53    time: 0.0256  data: 0.0010  max mem: 18331
Test:  [ 900/5000]  eta: 0:01:50    time: 0.0242  data: 0.0010  max mem: 18331
Test:  [1000/5000]  eta: 0:01:46    time: 0.0268  data: 0.0009  max mem: 18331
Test:  [1100/5000]  eta: 0:01:44    time: 0.0284  data: 0.0010  max mem: 18331
Test:  [1200/5000]  eta: 0:01:41    time: 0.0266  da

  return torch.tensor(val)


In [14]:
# Convert fake quant -> real quantized model (step 3 of the checklist).
# The model is moved to CPU first and a deep copy is converted, so qat_model
# itself is not passed to convert_fx.
qat_model.cpu()
int_qat_model = convert_fx(deepcopy(qat_model))

In [39]:
# Run fuse_fx on a copy of the converted model.
# NOTE(review): fused_int_qat_model is not used by any cell visible in this
# notebook — confirm whether this experiment is still needed.
fused_int_qat_model = fuse_fx(deepcopy(int_qat_model))

In [15]:
# Check the accuracy of the converted model on CPU (step 4 of the checklist).
confmat = evaluate(int_qat_model, data_loader_test, device='cpu', num_classes=num_classes)
print(confmat)

Test:  [   0/5000]  eta: 0:52:53    time: 0.6346  data: 0.4461  max mem: 18331
Test:  [ 100/5000]  eta: 0:12:58    time: 0.1477  data: 0.0009  max mem: 18331
Test:  [ 200/5000]  eta: 0:12:29    time: 0.1544  data: 0.0009  max mem: 18331
Test:  [ 300/5000]  eta: 0:12:07    time: 0.1579  data: 0.0009  max mem: 18331
Test:  [ 400/5000]  eta: 0:11:53    time: 0.1583  data: 0.0009  max mem: 18331
Test:  [ 500/5000]  eta: 0:11:40    time: 0.1538  data: 0.0009  max mem: 18331
Test:  [ 600/5000]  eta: 0:11:25    time: 0.1590  data: 0.0009  max mem: 18331
Test:  [ 700/5000]  eta: 0:11:10    time: 0.1606  data: 0.0009  max mem: 18331
Test:  [ 800/5000]  eta: 0:10:55    time: 0.1624  data: 0.0010  max mem: 18331
Test:  [ 900/5000]  eta: 0:10:40    time: 0.1525  data: 0.0008  max mem: 18331
Test:  [1000/5000]  eta: 0:10:26    time: 0.1667  data: 0.0009  max mem: 18331
Test:  [1100/5000]  eta: 0:10:11    time: 0.1571  data: 0.0009  max mem: 18331
Test:  [1200/5000]  eta: 0:09:56    time: 0.1625  da

In [8]:
# Insert fake quantization for the distillation experiment.
# `model` is used as the float teacher below, so a deep copy is handed over
# (as in the quantize_static cell) to keep the teacher untouched even if
# fake_quantization modifies its argument in place.
qat_model = fake_quantization(deepcopy(model), data_loader_test)



In [9]:
# Student (fake-quantized model) goes to the GPU.
qat_model.cuda();

In [10]:
# Teacher: `model` in eval mode on the GPU; it is passed to
# train_one_epoch_distill as t_model in the training cell below.
model.eval().cuda();

In [11]:
# Mean squared error between student and teacher outputs.
mse_loss = torch.nn.MSELoss()
# Backward-compatible alias: the original name said "rmse", but no square root
# is taken anywhere — the loss is plain MSE.
rmse_loss = mse_loss


def criterion_distill(inputs, target):
    """Distillation loss: MSE between student and teacher heads of the same name.

    Args:
        inputs: mapping from head name to the student's output tensor.
        target: mapping from head name to the teacher's output tensor.

    Returns:
        The ``"out"`` loss when only one head is shared, otherwise
        ``out + 0.5 * aux``.

    Heads are matched by key rather than by position, so the result does not
    depend on the order in which the two mappings list their heads. Heads that
    the teacher does not provide are skipped.
    """
    losses = {}
    for name, x in inputs.items():
        if name not in target:
            continue
        losses[name] = mse_loss(x, target[name])

    if len(losses) == 1:
        return losses["out"]

    return losses["out"] + 0.5 * losses["aux"]

def train_one_epoch_distill(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq, scaler=None, t_model=None):
    """Train ``model`` for one epoch, optionally distilling from a teacher.

    Args:
        model: student network; switched to train mode.
        criterion: supervised loss, called as ``criterion(output, target)``.
        optimizer: optimizer stepped once per batch.
        data_loader: iterable yielding ``(image, target)`` batches.
        lr_scheduler: scheduler stepped once per batch.
        device: device the batches are moved to.
        epoch: epoch index, used only in the log header.
        print_freq: logging period passed to the metric logger.
        scaler: optional gradient scaler; when given, the student forward pass
            and the losses run under autocast.
        t_model: optional teacher. When given, ``criterion_distill`` between
            the student and teacher outputs is added to the supervised loss.
            When ``None`` only the supervised loss is used (calling the
            original with the default crashed on ``None(image)``).
    """
    model.train()
    if t_model is not None:
        # The teacher only provides targets; make sure it is not in train mode.
        t_model.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    header = f"Epoch: [{epoch}]"
    for image, target in metric_logger.log_every(data_loader, print_freq, header):
        image, target = image.to(device), target.to(device)

        # Teacher forward pass: no gradients, and kept outside autocast.
        t_output = None
        if t_model is not None:
            with torch.no_grad():
                t_output = t_model(image)

        with torch.cuda.amp.autocast(enabled=scaler is not None):
            output = model(image)
            loss = criterion(output, target)
            if t_output is not None:
                d_loss = criterion_distill(output, t_output)
                loss = loss + d_loss

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # The scheduler is iteration-based, so it advances after every batch.
        lr_scheduler.step()

        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])

In [12]:
# Fine-tune with 1% of the reference learning rate.
# The unscaled value is remembered on `args` so that re-running this cell (or
# the identical cell of another experiment) does not compound the factor.
if not hasattr(args, "base_lr"):
    args.base_lr = args.lr
args.lr = args.base_lr * 0.01

In [13]:
# QAT + distillation training loop: same structure as the plain QAT loop, but
# each epoch runs train_one_epoch_distill with `model` as the teacher.
if args.output_dir:
    utils.mkdir(args.output_dir)

device = torch.device(args.device)

# Place student and teacher on the training device here instead of relying on
# earlier cells; the teacher is also put in eval mode since it only provides
# targets.
qat_model.to(device)
model.to(device)
model.eval()

dataset, num_classes = get_dataset(args, is_train=True)
dataset_test, _ = get_dataset(args, is_train=False)

train_sampler = torch.utils.data.RandomSampler(dataset)
test_sampler = torch.utils.data.SequentialSampler(dataset_test)

data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=args.batch_size,
    sampler=train_sampler,
    num_workers=args.workers,
    collate_fn=utils.collate_fn,
    drop_last=True,
)

# NOTE: this rebinds data_loader_test with batch_size=1.
data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
)

# Only the student's parameters are optimized; one group per sub-network.
params_to_optimize = [
    {"params": [p for p in qat_model.backbone.parameters() if p.requires_grad]},
    {"params": [p for p in qat_model.classifier.parameters() if p.requires_grad]},
]
if args.aux_loss:
    params = [p for p in qat_model.aux_classifier.parameters() if p.requires_grad]
    params_to_optimize.append({"params": params, "lr": args.lr})

optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

# Gradient scaling is used only when mixed precision is requested.
scaler = torch.cuda.amp.GradScaler() if args.amp else None

# The scheduler is sized in iterations (epochs * batches per epoch), not epochs.
iters_per_epoch = len(data_loader)
main_lr_scheduler = PolynomialLR(
    optimizer, total_iters=iters_per_epoch * (args.epochs - args.lr_warmup_epochs), power=0.9
)

if args.lr_warmup_epochs > 0:
    warmup_iters = iters_per_epoch * args.lr_warmup_epochs
    args.lr_warmup_method = args.lr_warmup_method.lower()
    if args.lr_warmup_method == "linear":
        warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=args.lr_warmup_decay, total_iters=warmup_iters
        )
    elif args.lr_warmup_method == "constant":
        warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
            optimizer, factor=args.lr_warmup_decay, total_iters=warmup_iters
        )
    else:
        raise RuntimeError(
            f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
        )
    # Warm-up first, then the polynomial decay.
    lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
    )
else:
    lr_scheduler = main_lr_scheduler

start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
    train_one_epoch_distill(qat_model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq, scaler, model)
    # Validate the student after every epoch.
    confmat = evaluate(qat_model, data_loader_test, device=device, num_classes=num_classes)
    print(confmat)
    checkpoint = {
        "model": qat_model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
        "epoch": epoch,
        "args": args,
    }
    if args.amp:
        checkpoint["scaler"] = scaler.state_dict()
    # Keep a per-epoch snapshot plus a rolling "latest" checkpoint.
    utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
    utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f"Training time {total_time_str}")

loading annotations into memory...
Done (t=7.54s)
creating index...
index created!
loading annotations into memory...
Done (t=0.25s)
creating index...
index created!


  return F.conv2d(input, weight, bias, self.stride,
  return torch.fused_moving_avg_obs_fake_quant(
  return torch.fused_moving_avg_obs_fake_quant(


Epoch: [0]  [   0/2891]  eta: 4:18:57  lr: 9.996886836501132e-05  loss: 4.6854 (4.6854)  time: 5.3745  data: 2.7601  max mem: 20709
Epoch: [0]  [  10/2891]  eta: 0:48:45  lr: 9.96574926992653e-05  loss: 4.8821 (5.1211)  time: 1.0156  data: 0.2551  max mem: 20740
Epoch: [0]  [  20/2891]  eta: 0:38:26  lr: 9.934600889796806e-05  loss: 4.8799 (4.8753)  time: 0.5749  data: 0.0039  max mem: 20740
Epoch: [0]  [  30/2891]  eta: 0:34:49  lr: 9.903441654658893e-05  loss: 4.4608 (4.7101)  time: 0.5732  data: 0.0035  max mem: 20740
Epoch: [0]  [  40/2891]  eta: 0:32:55  lr: 9.87227152275529e-05  loss: 3.9931 (4.4981)  time: 0.5769  data: 0.0036  max mem: 20740
Epoch: [0]  [  50/2891]  eta: 0:31:41  lr: 9.841090452020753e-05  loss: 3.9512 (4.4092)  time: 0.5750  data: 0.0038  max mem: 20740
Epoch: [0]  [  60/2891]  eta: 0:30:47  lr: 9.809898400078932e-05  loss: 3.9656 (4.2992)  time: 0.5698  data: 0.0041  max mem: 20740
Epoch: [0]  [  70/2891]  eta: 0:30:05  lr: 9.778695324238973e-05  loss: 3.6022

KeyboardInterrupt: 