### Делаем Quantization Aware Training. Используем готовый трейнлуп из PyTorch

In [1]:
import datetime
import os
import pickle
import time
from copy import deepcopy
from pathlib import Path

import torch
import torch.utils.data
from torch import nn
from torch.ao.quantization.quantize_fx import convert_fx
from torch.ao.quantization.quantize_fx import fuse_fx
from torch.optim.lr_scheduler import PolynomialLR
from torchvision.models.segmentation import DeepLabV3_MobileNet_V3_Large_Weights, deeplabv3_mobilenet_v3_large
from tqdm import tqdm

import utils
from quantization_utils.fake_quantization import fake_quantization
from quantization_utils.static_quantization import quantize_static
from train import evaluate
from train import get_dataset
from train import train_one_epoch

  from .autonotebook import tqdm as notebook_tqdm


### Берём функцию main() из скрипта train.py и вытаскиваем из неё трейнлуп

In [2]:
def criterion(inputs, target):
    """Return the segmentation loss for a dict of model outputs.

    Cross-entropy (with label 255 ignored) is computed for every entry of
    ``inputs`` against ``target``. When ``inputs`` holds exactly one entry the
    result is the ``"out"`` loss alone; otherwise it is the ``"out"`` loss plus
    half of the ``"aux"`` loss.

    Args:
        inputs: mapping from head name to logits tensor; must contain ``"out"``
            and, if it has more than one entry, ``"aux"``.
        target: tensor of class indices passed to ``cross_entropy``.

    Returns:
        The combined loss tensor.

    Raises:
        KeyError: if a required head name is missing from ``inputs``.
    """
    per_head = {
        head: nn.functional.cross_entropy(logits, target, ignore_index=255)
        for head, logits in inputs.items()
    }

    total = per_head["out"]
    if len(per_head) != 1:
        # The auxiliary head contributes with a fixed weight of 0.5.
        total = total + 0.5 * per_head["aux"]
    return total

def train_one_epoch(student_model, teacher_model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq, scaler=None):
    """Train ``student_model`` for one pass over ``data_loader``.

    The loss is ``0.5 * criterion(student_out, target) + 0.5 * KD``, where the
    knowledge-distillation term ``KD`` is the MSE between student and teacher
    ``"out"`` logits plus half the MSE between their ``"aux"`` logits (when
    both models produce an ``"aux"`` output). With ``teacher_model=None`` the
    KD term is zero.

    Note: this definition replaces the ``train_one_epoch`` imported from
    ``train`` earlier in the file.

    Args:
        student_model: model being optimized; called as ``student_model(image)``
            and expected to return a dict with an ``"out"`` entry.
        teacher_model: optional frozen reference model, or ``None`` to disable
            distillation. It is only run forward, never updated.
        criterion: callable ``(outputs_dict, target) -> loss``.
        optimizer: optimizer holding the student's parameters.
        data_loader: iterable yielding ``(image, target)`` batches.
        lr_scheduler: scheduler stepped once per batch (not once per epoch).
        device: device the batches are moved to.
        epoch: epoch index, used only in the log header.
        print_freq: logging period passed to ``MetricLogger.log_every``.
        scaler: optional ``GradScaler``; when given, autocast is enabled and
            the backward/step go through the scaler.

    Returns:
        None.
    """
    base_k = 0.5
    KD_k = 0.5

    student_model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    header = f"Epoch: [{epoch}]"
    for image, target in metric_logger.log_every(data_loader, print_freq, header):
        image, target = image.to(device), target.to(device)
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            student_out = student_model(image)
            if teacher_model is not None:
                # The teacher is a fixed target: run it without autograd so no
                # graph is kept for it and no gradients accumulate on its
                # parameters. The student's gradients are unaffected.
                with torch.no_grad():
                    teacher_out = teacher_model(image)
                # KD on the main-head logits.
                KDloss = nn.functional.mse_loss(student_out['out'], teacher_out['out'])
                # Also distill the auxiliary head, but only when both models
                # actually expose one (mirrors how `criterion` tolerates a
                # single-head output).
                if 'aux' in student_out and 'aux' in teacher_out:
                    KDloss = KDloss + 0.5 * nn.functional.mse_loss(student_out['aux'], teacher_out['aux'])
            else:
                KDloss = 0
            loss = base_k*criterion(student_out, target) + KD_k*KDloss

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # Per-iteration schedule: the scheduler advances after every batch.
        lr_scheduler.step()

        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])

In [3]:
def train(student_model, teacher_model, args):
    """Run the full training loop for ``student_model`` (optionally with a teacher).

    Steps, in order: create ``args.output_dir`` if set, call
    ``utils.init_distributed_mode``, build the train/test datasets and loaders,
    move the models to ``args.device``, build an SGD optimizer and a
    per-iteration polynomial LR schedule (with optional warmup), optionally
    resume from ``args.resume``, then for each epoch train, evaluate, print the
    confusion matrix and save a checkpoint.

    Args:
        student_model: model to train; must expose ``backbone`` and
            ``classifier`` (and ``aux_classifier`` when ``args.aux_loss`` is set).
        teacher_model: optional model forwarded to ``train_one_epoch`` for
            distillation; put in eval mode here. May be ``None``.
        args: namespace of settings; the attributes read are visible in the body
            (``device``, ``batch_size``, ``workers``, ``lr``, ``epochs``, ...).

    Returns:
        None. Checkpoints are written under ``args.output_dir``.

    Raises:
        RuntimeError: if warmup is enabled and ``args.lr_warmup_method`` is
            neither ``"linear"`` nor ``"constant"``.
    """

    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    #torch.backends.cudnn.benchmark = False
    #torch.use_deterministic_algorithms(True)


    dataset, num_classes = get_dataset(args, is_train=True)
    dataset_test, _ = get_dataset(args, is_train=False)

    # Samplers are chosen unconditionally; `args.distributed` is not consulted here.
    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        collate_fn=utils.collate_fn,
        drop_last=True,
    )

    # Evaluation always uses batch size 1, independent of args.batch_size.
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, 
        batch_size=1, 
        sampler=test_sampler, 
        num_workers=args.workers, 
        collate_fn=utils.collate_fn
    )

    student_model.to(device)
    if teacher_model is not None:
        teacher_model.to(device)
        teacher_model.eval()

    # No DistributedDataParallel wrapper is applied in this function, so this
    # is simply an alias for the student model.
    model_without_ddp = student_model

    # One parameter group per sub-module; only trainable parameters are included.
    params_to_optimize = [
        {"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]},
        {"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]},
    ]
    if args.aux_loss:
        # The auxiliary classifier is trained with a 10x larger learning rate.
        params = [p for p in model_without_ddp.aux_classifier.parameters() if p.requires_grad]
        params_to_optimize.append({"params": params, "lr": args.lr * 10})
    
    optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    # Mixed precision is enabled purely by the presence of a scaler.
    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    # The schedule is expressed in iterations (the scheduler is stepped once
    # per batch inside train_one_epoch), hence the iters_per_epoch factor.
    iters_per_epoch = len(data_loader)
    main_lr_scheduler = PolynomialLR(
        optimizer, total_iters=iters_per_epoch * (args.epochs - args.lr_warmup_epochs), power=0.9
    )

    if args.lr_warmup_epochs > 0:
        warmup_iters = iters_per_epoch * args.lr_warmup_epochs
        # Note: normalizes the method name in place on the shared args object.
        args.lr_warmup_method = args.lr_warmup_method.lower()
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=args.lr_warmup_decay, total_iters=warmup_iters
            )
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=args.lr_warmup_decay, total_iters=warmup_iters
            )
        else:
            raise RuntimeError(
                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
            )
        # Warmup runs first, then the polynomial schedule takes over at warmup_iters.
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
        )
    else:
        lr_scheduler = main_lr_scheduler

    if args.resume:
        # NOTE(review): checkpoints written below store the `args` object under
        # "args"; confirm that loading them with weights_only=True succeeds.
        checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
        model_without_ddp.load_state_dict(checkpoint["model"], strict=not args.test_only)
        if not args.test_only:
            # Restore the full training state and continue after the saved epoch.
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            args.start_epoch = checkpoint["epoch"] + 1
            if args.amp:
                scaler.load_state_dict(checkpoint["scaler"])

    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # NOTE(review): train_sampler is always a RandomSampler here, which
            # does not appear to provide set_epoch — confirm whether distributed
            # runs are meant to be supported by this function.
            train_sampler.set_epoch(epoch)
        train_one_epoch(student_model, teacher_model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq, scaler)
        confmat = evaluate(student_model, data_loader_test, device=device, num_classes=num_classes)
        print(confmat)
        checkpoint = {
            "model": model_without_ddp.state_dict(),
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
            "epoch": epoch,
            "args": args,
        }
        if args.amp:
            checkpoint["scaler"] = scaler.state_dict()
        # Keep a per-epoch snapshot and also overwrite the rolling "latest" checkpoint.
        utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
        utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")

In [4]:
# Notebook cell: free cached memory, load the run configuration, build the
# calibration/evaluation loaders, wrap the pretrained model with fake
# quantization and launch QAT via train().
import gc
gc.collect()
torch.cuda.empty_cache()

# Print current GPU memory usage (only meaningful when a CUDA device exists;
# querying device 0 without one would raise).
if torch.cuda.is_available():
    t = torch.cuda.get_device_properties(0).total_memory
    r = torch.cuda.memory_reserved(0)
    a = torch.cuda.memory_allocated(0)
    f = r-a  # free inside reserved

    print(f'Total: {t}, Reserved: {r}, Allocated: {a}, Free: {f}')

# The default arguments were dumped to a pickle so that argparse does not have
# to be driven from inside the notebook.
# NOTE: pickle.load can execute arbitrary code; only load a file you created
# yourself or otherwise trust.
with Path('./torch_default_args.pickle').open('rb') as file:
    args = pickle.load(file)

# Tune these for your hardware.
args.data_path = '/home/gvasserm/data/coco2017/'
args.epochs = 1
args.batch_size = 24
args.workers = 8

print(args)

model = deeplabv3_mobilenet_v3_large(weights=DeepLabV3_MobileNet_V3_Large_Weights.DEFAULT)
model.eval()

if args.output_dir:
    utils.mkdir(args.output_dir)

utils.init_distributed_mode(args)

device = torch.device(args.device)

dataset_test, num_classes = get_dataset(args, is_train=False)

dataset_train, num_classes = get_dataset(args, is_train=True)

# Both samplers are sequential: the train loader built here is only handed to
# fake_quantization(); train() builds its own shuffled loader.
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
train_sampler = torch.utils.data.SequentialSampler(dataset_train)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=16, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
)

# Use the configured batch size instead of repeating the literal, so that
# changing args.batch_size above is enough to fit different hardware.
data_loader_train = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=args.batch_size,
    sampler=train_sampler,
    num_workers=args.workers,
    collate_fn=utils.collate_fn,
)

torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Place the model on the configured device (the same one train() uses)
# rather than hard-coding CUDA.
model.to(device)


qat_model = fake_quantization(model, data_loader_train)
qat_model.to(device)

# No teacher model: plain QAT without knowledge distillation.
train(qat_model, None, args)

Total: 16899571712, Reserved: 0, Allocated: 0, Free: 0
Namespace(data_path='/home/gvasserm/data/coco2017/', dataset='coco', model='deeplabv3_mobilenet_v3_large', aux_loss=False, device='cuda', batch_size=24, epochs=1, workers=8, lr=0.01, momentum=0.9, weight_decay=0.0001, lr_warmup_epochs=0, lr_warmup_method='linear', lr_warmup_decay=0.01, print_freq=10, output_dir='.', resume='', start_epoch=0, test_only=False, use_deterministic_algorithms=False, world_size=1, dist_url='env://', weights=None, weights_backbone=None, amp=False, backend='pil', use_v2=False)
Not using distributed mode
loading annotations into memory...
Done (t=0.28s)
creating index...
index created!
loading annotations into memory...
Done (t=7.55s)
creating index...
index created!




Not using distributed mode
loading annotations into memory...
Done (t=8.08s)
creating index...
index created!
loading annotations into memory...
Done (t=0.25s)
creating index...
index created!


  return torch.fused_moving_avg_obs_fake_quant(
  return torch.fused_moving_avg_obs_fake_quant(


Epoch: [0]  [   0/3854]  eta: 3:30:55  lr: 0.009997664733582535  loss: 0.3979 (0.3979)  time: 3.2838  data: 1.4601  max mem: 13843
Epoch: [0]  [  10/3854]  eta: 1:13:40  lr: 0.009974308733008223  loss: 0.4082 (0.3969)  time: 1.1499  data: 0.1387  max mem: 13869
Epoch: [0]  [  20/3854]  eta: 1:07:09  lr: 0.009950946654092182  loss: 0.4076 (0.4009)  time: 0.9394  data: 0.0066  max mem: 13869
Epoch: [0]  [  30/3854]  eta: 1:04:43  lr: 0.00992757847938837  loss: 0.3978 (0.4039)  time: 0.9417  data: 0.0067  max mem: 13869
Epoch: [0]  [  40/3854]  eta: 1:03:28  lr: 0.009904204191354918  loss: 0.3998 (0.4007)  time: 0.9436  data: 0.0068  max mem: 13869
Epoch: [0]  [  50/3854]  eta: 1:02:42  lr: 0.00988082377235332  loss: 0.3928 (0.4021)  time: 0.9478  data: 0.0068  max mem: 13869
Epoch: [0]  [  60/3854]  eta: 1:02:08  lr: 0.009857437204647676  loss: 0.4004 (0.4104)  time: 0.9504  data: 0.0069  max mem: 13869
Epoch: [0]  [  70/3854]  eta: 1:01:40  lr: 0.009834044470403858  loss: 0.4019 (0.4120

  return torch.tensor(val)


Training time 1:08:42


In [5]:
# Inference is done on the CPU, so the model is moved there first and then
# converted (Module.cpu() returns the module itself, so the calls chain).
int_qat_model = convert_fx(qat_model.cpu())

In [6]:
# The fake-quant model and the converted quantized model will not score
# identically. That is expected: during training quantization was only emulated.
int_qat_model.cpu()
data_loader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=16,
    sampler=test_sampler,
    num_workers=args.workers,
    collate_fn=utils.collate_fn,
)
confmat = evaluate(int_qat_model, data_loader_test, device='cpu', num_classes=num_classes)
print(confmat)

Test:  [  0/313]  eta: 0:36:51    time: 7.0646  data: 0.8907  max mem: 13869
Test:  [100/313]  eta: 0:18:17    time: 4.8674  data: 0.0012  max mem: 13869
Test:  [200/313]  eta: 0:09:41    time: 5.0602  data: 0.0013  max mem: 13869
Test:  [300/313]  eta: 0:01:07    time: 5.1654  data: 0.0012  max mem: 13869
Test: Total time: 0:26:56
global correct: 90.3
average row correct: ['94.8', '79.3', '62.8', '65.0', '49.5', '47.3', '70.5', '49.8', '91.2', '34.0', '75.8', '43.5', '75.0', '72.2', '74.7', '85.6', '43.2', '79.4', '46.6', '64.6', '59.0']
IoU: ['89.5', '53.6', '53.1', '40.7', '31.7', '28.0', '62.6', '41.5', '65.6', '26.3', '57.3', '28.7', '58.0', '61.4', '62.8', '73.8', '21.6', '56.1', '38.7', '57.8', '43.4']
mean IoU: 50.1
