### Балуемся с дистилляцией
Врываемся в train.py и добавляем туда дистилляцию: просто по последнему слою (до софтмакса, на логитах) делаем стягивание по MSE.

Цель — поднять точность и ускорить сходимость.

Балуемся с весами обычного и distill лосса.

Можно вообще выкинуть classification loss и смоделировать ситуацию, когда вам не выдали лейблов (жиза).

In [1]:
import datetime
import os
import pickle
import time
from copy import deepcopy
from pathlib import Path

import torch
import torch.utils.data
from torch import nn
from torch.ao.quantization.quantize_fx import convert_fx
from torch.ao.quantization.quantize_fx import fuse_fx
from torch.optim.lr_scheduler import PolynomialLR
from torchvision.models.segmentation import DeepLabV3_MobileNet_V3_Large_Weights, deeplabv3_mobilenet_v3_large
from tqdm import tqdm

import utils
from quantization_utils.fake_quantization import fake_quantization
from quantization_utils.static_quantization import quantize_static
from train import evaluate
from train import get_dataset
from train import train_one_epoch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def criterion(inputs, target):
    """Compute the supervised segmentation loss over every output head.

    Args:
        inputs: Mapping from head name to logits. The ``"out"`` head is
            required; when more than one head is present an ``"aux"`` head
            is required as well.
        target: Ground-truth label tensor; entries equal to 255 are ignored.

    Returns:
        The cross-entropy of the ``"out"`` head, plus half the cross-entropy
        of the ``"aux"`` head when ``inputs`` holds more than one head.
    """
    per_head = {
        head: nn.functional.cross_entropy(logits, target, ignore_index=255)
        for head, logits in inputs.items()
    }
    main_loss = per_head["out"]

    if len(per_head) != 1:
        # The auxiliary head contributes with a fixed 0.5 weight.
        return main_loss + 0.5 * per_head["aux"]

    return main_loss

def _distillation_loss(student_out, teacher_out, aux_weight=0.5):
    """Return the MSE between student and teacher logits (plus the aux head if both have one)."""
    kd_loss = nn.functional.mse_loss(student_out["out"], teacher_out["out"])
    if "aux" in student_out and "aux" in teacher_out:
        # Models built without an auxiliary head simply skip this term
        # instead of failing with a KeyError.
        kd_loss = kd_loss + aux_weight * nn.functional.mse_loss(student_out["aux"], teacher_out["aux"])
    return kd_loss


def train_one_epoch(
    student_model,
    teacher_model,
    criterion,
    optimizer,
    data_loader,
    lr_scheduler,
    device,
    epoch,
    print_freq,
    scaler=None,
    base_k=0.5,
    kd_k=0.5,
):
    """Train the student for one epoch with an optional knowledge-distillation term.

    The total loss is ``base_k * criterion(student_out, target) + kd_k * KD``,
    where ``KD`` is the MSE between student and teacher logits (main head plus
    0.5 * auxiliary head when both models produce one).

    Args:
        student_model: Model being optimised; it is switched to train mode here.
        teacher_model: Model providing target logits, or ``None`` to disable
            distillation. Its mode is not changed by this function.
        criterion: Callable ``(outputs_dict, target) -> loss`` for the supervised term.
        optimizer: Optimizer holding the student parameters.
        data_loader: Iterable yielding ``(image, target)`` batches.
        lr_scheduler: Scheduler stepped once per batch.
        device: Device the batches are moved to.
        epoch: Epoch index, used only in the log header.
        print_freq: Logging period passed to the metric logger.
        scaler: Optional gradient scaler; when given, autocast is enabled and
            the backward/step go through the scaler.
        base_k: Weight of the supervised loss.
        kd_k: Weight of the distillation loss.
    """
    student_model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    header = f"Epoch: [{epoch}]"
    for image, target in metric_logger.log_every(data_loader, print_freq, header):
        image, target = image.to(device), target.to(device)
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            student_out = student_model(image)
            loss = base_k * criterion(student_out, target)
            if teacher_model is not None:
                # The teacher is never optimised, so its forward pass must not
                # record an autograd graph; this saves memory and backward time
                # without changing the student's gradients.
                with torch.no_grad():
                    teacher_out = teacher_model(image)
                loss = loss + kd_k * _distillation_loss(student_out, teacher_out)

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        lr_scheduler.step()

        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])

In [3]:
def _build_lr_scheduler(optimizer, iters_per_epoch, args):
    """Create the per-iteration LR schedule: optional warmup followed by polynomial decay."""
    main_lr_scheduler = PolynomialLR(
        optimizer, total_iters=iters_per_epoch * (args.epochs - args.lr_warmup_epochs), power=0.9
    )

    if not args.lr_warmup_epochs > 0:
        return main_lr_scheduler

    warmup_iters = iters_per_epoch * args.lr_warmup_epochs
    # The normalised method name is written back to args, as callers may read it later.
    args.lr_warmup_method = args.lr_warmup_method.lower()
    if args.lr_warmup_method == "linear":
        warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=args.lr_warmup_decay, total_iters=warmup_iters
        )
    elif args.lr_warmup_method == "constant":
        warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
            optimizer, factor=args.lr_warmup_decay, total_iters=warmup_iters
        )
    else:
        raise RuntimeError(
            f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
        )
    return torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
    )


def train(student_model, teacher_model, args):
    """Train ``student_model`` (optionally distilling from ``teacher_model``) and checkpoint every epoch.

    After each epoch the student is evaluated on the test split, the confusion
    matrix is printed, and a checkpoint is written to ``args.output_dir`` both
    as ``model_{epoch}.pth`` and as ``checkpoint.pth``.

    Args:
        student_model: Model to optimise. It must expose ``backbone`` and
            ``classifier`` sub-modules, and ``aux_classifier`` when
            ``args.aux_loss`` is set.
        teacher_model: Model used as distillation target, or ``None``. It is
            moved to the training device and put in eval mode.
        args: Namespace with the training configuration (device, batch_size,
            workers, lr, momentum, weight_decay, epochs, warmup settings,
            resume, amp, output_dir, print_freq, ...).

    Raises:
        RuntimeError: If ``args.lr_warmup_method`` is neither ``linear`` nor ``constant``.
    """
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    dataset, num_classes = get_dataset(args, is_train=True)
    dataset_test, _ = get_dataset(args, is_train=False)

    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        collate_fn=utils.collate_fn,
        drop_last=True,
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=1,
        sampler=test_sampler,
        num_workers=args.workers,
        collate_fn=utils.collate_fn,
    )

    student_model.to(device)
    if teacher_model is not None:
        teacher_model.to(device)
        teacher_model.eval()

    # No DDP wrapper is applied in this function, so the raw student is what gets saved/loaded.
    model_without_ddp = student_model

    params_to_optimize = [
        {"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]},
        {"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]},
    ]
    if args.aux_loss:
        params = [p for p in model_without_ddp.aux_classifier.parameters() if p.requires_grad]
        # The auxiliary head trains with a 10x larger learning rate.
        params_to_optimize.append({"params": params, "lr": args.lr * 10})

    optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    # The scheduler is stepped once per batch, hence everything is counted in iterations.
    lr_scheduler = _build_lr_scheduler(optimizer, len(data_loader), args)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
        model_without_ddp.load_state_dict(checkpoint["model"], strict=not args.test_only)
        if not args.test_only:
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            args.start_epoch = checkpoint["epoch"] + 1
            if args.amp:
                scaler.load_state_dict(checkpoint["scaler"])

    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        # The sampler built above is a RandomSampler, which has no set_epoch();
        # calling it unconditionally raised AttributeError in distributed mode.
        # NOTE(review): distributed training is not really set up here (no
        # DistributedSampler / DDP wrapper) - confirm whether it is needed.
        if args.distributed and hasattr(train_sampler, "set_epoch"):
            train_sampler.set_epoch(epoch)
        train_one_epoch(student_model, teacher_model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq, scaler)
        confmat = evaluate(student_model, data_loader_test, device=device, num_classes=num_classes)
        print(confmat)
        checkpoint = {
            "model": model_without_ddp.state_dict(),
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
            "epoch": epoch,
            "args": args,
        }
        if args.amp:
            checkpoint["scaler"] = scaler.state_dict()
        utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
        utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")

In [5]:
import gc

# Release Python garbage and cached CUDA blocks before the run.
gc.collect()
torch.cuda.empty_cache()

# Print current GPU memory usage
t = torch.cuda.get_device_properties(0).total_memory
r = torch.cuda.memory_reserved(0)
a = torch.cuda.memory_allocated(0)
f = r-a  # free inside reserved

print(f'Total: {t}, Reserved: {r}, Allocated: {a}, Free: {f}')

# The default arguments were dumped to a pickle to avoid wrestling with argparse in a notebook.
# NOTE: pickle.load can execute arbitrary code - only load a file you produced yourself.
with Path('./torch_default_args.pickle').open('rb') as file:
    args = pickle.load(file)

# Tune these for your hardware
args.data_path = '/home/gvasserm/data/coco2017/'
args.epochs = 1
args.batch_size = 16
args.workers = 8

print(args)

# Pretrained FP32 model; it serves both as the source for the QAT student and as the KD teacher.
model = deeplabv3_mobilenet_v3_large(weights=DeepLabV3_MobileNet_V3_Large_Weights.DEFAULT)
model.eval()

if args.output_dir:
    utils.mkdir(args.output_dir)

utils.init_distributed_mode(args)

device = torch.device(args.device)

dataset_test, num_classes = get_dataset(args, is_train=False)

dataset_train, num_classes = get_dataset(args, is_train=True)

test_sampler = torch.utils.data.SequentialSampler(dataset_test)
train_sampler = torch.utils.data.SequentialSampler(dataset_train)

# Only the train loader is needed before training; the test loader used for
# the final evaluation is built further down, right before it is consumed.
data_loader_train = torch.utils.data.DataLoader(
    dataset_train, batch_size=24, sampler=train_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
)

# Disable cuDNN autotuning and request deterministic kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Use the device configured in args rather than a hard-coded .cuda().
model.to(device)

# NOTE(review): presumably prepares a fake-quantized copy of the model using
# data_loader_train - confirm in quantization_utils.fake_quantization.
qat_model = fake_quantization(model, data_loader_train)
qat_model.to(device)

# Student = fake-quantized model, teacher = original FP32 model.
train(qat_model, model, args)

# Inference runs on the CPU: move the model to the CPU first, then convert it.
qat_model.cpu()
int_qat_model = convert_fx(qat_model)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=12, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
)
confmat = evaluate(int_qat_model, data_loader_test, device='cpu', num_classes=num_classes)
print(confmat)

Total: 16899571712, Reserved: 0, Allocated: 0, Free: 0
Namespace(data_path='/home/gvasserm/data/coco2017/', dataset='coco', model='deeplabv3_mobilenet_v3_large', aux_loss=False, device='cuda', batch_size=16, epochs=1, workers=8, lr=0.01, momentum=0.9, weight_decay=0.0001, lr_warmup_epochs=0, lr_warmup_method='linear', lr_warmup_decay=0.01, print_freq=10, output_dir='.', resume='', start_epoch=0, test_only=False, use_deterministic_algorithms=False, world_size=1, dist_url='env://', weights=None, weights_backbone=None, amp=False, backend='pil', use_v2=False)
Not using distributed mode
loading annotations into memory...
Done (t=0.27s)
creating index...
index created!
loading annotations into memory...
Done (t=7.98s)
creating index...
index created!




Not using distributed mode
loading annotations into memory...
Done (t=9.13s)
creating index...
index created!
loading annotations into memory...
Done (t=0.25s)
creating index...
index created!


  return torch.fused_moving_avg_obs_fake_quant(
  return torch.fused_moving_avg_obs_fake_quant(


Epoch: [0]  [   0/5782]  eta: 4:46:34  lr: 0.009998443431713478  loss: 2.6787 (2.6787)  time: 2.9739  data: 1.1832  max mem: 14888
Epoch: [0]  [  10/5782]  eta: 2:00:54  lr: 0.009982876267081917  loss: 2.6048 (2.6146)  time: 1.2568  data: 0.1119  max mem: 14987
Epoch: [0]  [  20/5782]  eta: 1:53:10  lr: 0.009967306404733907  loss: 2.4871 (2.5841)  time: 1.0888  data: 0.0048  max mem: 14987
Epoch: [0]  [  30/5782]  eta: 1:50:21  lr: 0.009951733839518005  loss: 2.5164 (2.5239)  time: 1.0932  data: 0.0051  max mem: 14987
Epoch: [0]  [  40/5782]  eta: 1:48:54  lr: 0.009936158566263944  loss: 2.2878 (2.4690)  time: 1.0954  data: 0.0053  max mem: 14987
Epoch: [0]  [  50/5782]  eta: 1:47:56  lr: 0.00992058057978255  loss: 2.2419 (2.4135)  time: 1.0966  data: 0.0053  max mem: 14987
Epoch: [0]  [  60/5782]  eta: 1:47:15  lr: 0.009904999874865638  loss: 2.0757 (2.3853)  time: 1.0976  data: 0.0053  max mem: 14987
Epoch: [0]  [  70/5782]  eta: 1:46:44  lr: 0.0098894164462859  loss: 2.0396 (2.3273)

  return torch.tensor(val)


Training time 1:50:52
Test:  [  0/417]  eta: 0:35:56    time: 5.1713  data: 1.0757  max mem: 14987
Test:  [100/417]  eta: 0:20:07    time: 3.7658  data: 0.0013  max mem: 14987
Test:  [200/417]  eta: 0:13:35    time: 3.4748  data: 0.0013  max mem: 14987
Test:  [300/417]  eta: 0:07:16    time: 3.4550  data: 0.0014  max mem: 14987
Test:  [400/417]  eta: 0:01:03    time: 3.5409  data: 0.0013  max mem: 14987
Test: Total time: 0:25:52
global correct: 90.8
average row correct: ['95.1', '76.2', '61.2', '65.0', '53.5', '40.5', '68.2', '51.6', '86.5', '29.5', '70.0', '51.5', '76.6', '76.4', '71.3', '84.6', '41.1', '80.6', '56.8', '83.7', '60.8']
IoU: ['89.9', '63.8', '52.7', '54.3', '41.8', '31.5', '63.1', '43.7', '72.0', '24.0', '60.2', '32.1', '54.1', '59.5', '63.4', '74.5', '25.9', '67.6', '42.8', '63.9', '50.2']
mean IoU: 53.8


|          | FP32 | Static | QAT  | QAT + KD |
|----------|------|--------|------|----------|
| mean IoU | 56.4 | 48.4   | 50.1 | 53.8     |