# Week 9: Efficient model inference


### Seminar outline
1. Static PTQ
    - Toy example
    - MobileNetV2 on CIFAR10
    - QAT for MobileNetV2
    - Speed benchmark
2. 42 GB T5 to a single GPU showcase
    
## Static PTQ
### Toy example
[Source](https://pytorch.org/docs/stable/quantization.html) of the section

In [1]:
import copy
import os
import warnings
from time import time

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.ao.quantization import DeQuantStub, QuantStub
from torch.utils.data import DataLoader
from tqdm.auto import trange

warnings.filterwarnings("ignore")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 5, 3)
        self.relu = torch.nn.ReLU()
        self.flatten = torch.nn.Flatten()
        self.linear = torch.nn.Linear(4500, 100)
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        # manually specify where tensors will be converted from floating
        # point to quantized in the quantized model
        x = self.quant(x)
        x = self.conv(x)
        x = self.relu(x)
        x = self.linear(self.flatten(x))
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
        x = self.dequant(x)
        return x


# create a model instance
model_fp32 = M()

# model must be set to eval mode for static quantization logic to work
model_fp32.eval()

# attach a global qconfig, which contains information about what kind
# of observers to attach. Use 'fbgemm' for server inference and
# 'qnnpack' for mobile inference. Other quantization configurations such
# as selecting symmetric or asymmetric quantization and MinMax or L2Norm
# calibration techniques can be specified here.
model_fp32.qconfig = torch.quantization.get_default_qconfig("fbgemm")

# Fuse the activations to preceding layers, where applicable.
# This needs to be done manually depending on the model architecture.
# Common fusions include `conv + relu` and `conv + batchnorm + relu`
model_fp32_fused = torch.quantization.fuse_modules(model_fp32, [["conv", "relu"]])

# Prepare the model for static quantization. This inserts observers in
# the model that will observe activation tensors during calibration.
model_fp32_prepared = torch.quantization.prepare(model_fp32_fused)

# calibrate the prepared model to determine quantization parameters for activations
# in a real world setting, the calibration would be done with a representative dataset
input_fp32 = torch.randn(4, 1, 32, 32)
model_fp32_prepared(input_fp32)

# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and zero-point values to be
# used with each activation tensor, and replaces key operators with quantized
# implementations.
model_int8 = torch.quantization.convert(model_fp32_prepared)

# run the model, relevant calculations will happen in int8
res = model_int8(input_fp32)

In [3]:
model_int8

M(
  (quant): Quantize(scale=tensor([0.0534]), zero_point=tensor([64]), dtype=torch.quint8)
  (conv): QuantizedConvReLU2d(1, 5, kernel_size=(3, 3), stride=(1, 1), scale=0.019489329308271408, zero_point=0)
  (relu): Identity()
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear): QuantizedLinear(in_features=4500, out_features=100, scale=0.011786960065364838, zero_point=61, qscheme=torch.per_channel_affine)
  (dequant): DeQuantize()
)

In [4]:
%%timeit
res = model_int8(input_fp32)

476 µs ± 6.54 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [5]:
%%timeit
res = model_fp32(input_fp32)

319 µs ± 2.54 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


Why is there no speed-up? The int8 model is actually slower here (476 µs vs. 319 µs per call).

In [6]:
torch.save(model_int8.state_dict(), "test_model_q.pth")
torch.save(model_fp32.state_dict(), "test_model_full.pth")

In [7]:
!ls -al test_model_q.pth

-rw-rw-r-- 1 ubuntu ubuntu 456379 Mar 13 23:51 test_model_q.pth


In [8]:
!ls -al test_model_full.pth

-rw-rw-r-- 1 ubuntu ubuntu 1802207 Mar 13 23:51 test_model_full.pth
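
The two file sizes are consistent with a back-of-the-envelope parameter count for `M`. The sketch below assumes 4 bytes per fp32 value and 1 byte per int8 value, and ignores quantization parameters and serialization overhead, so it only approximates the numbers printed by `ls`.

```python
# Parameter count of the toy model M defined above.
conv_params = 1 * 5 * 3 * 3 + 5    # Conv2d(1, 5, 3): weights + biases
linear_params = 4500 * 100 + 100   # Linear(4500, 100): weights + biases
n_params = conv_params + linear_params

fp32_bytes = n_params * 4          # 4 bytes per float32 value
int8_bytes = n_params * 1          # assumption: every value is stored as int8

# File sizes reported by `ls` above.
fp32_file, int8_file = 1802207, 456379

print(n_params)                         # 450150
print(fp32_bytes, int8_bytes)           # 1800600 450150
print(round(fp32_file / int8_file, 2))  # 3.95
```

The estimate lands within a few kilobytes of both files, and the measured ratio is close to the ideal 4x.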


### MobileNetV2 on CIFAR10
[Source](https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html) of the section

In [9]:
def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class ConvBNReLU(nn.Sequential):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        padding = (kernel_size - 1) // 2
        super().__init__(
            nn.Conv2d(
                in_planes,
                out_planes,
                kernel_size,
                stride,
                padding,
                groups=groups,
                bias=False,
            ),
            nn.BatchNorm2d(out_planes, momentum=0.1),
            nn.ReLU(inplace=False),
        )


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
        layers.extend(
            [
                # dw
                ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup, momentum=0.1),
            ]
        )
        self.conv = nn.Sequential(*layers)
        # Replace torch.add with floatfunctional
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        if self.use_res_connect:
            return self.skip_add.add(x, self.conv(x))
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(
        self,
        num_classes=1000,
        width_mult=1.0,
        inverted_residual_setting=None,
        round_nearest=8,
    ):
        """
        MobileNet V2 main class
        Args:
            num_classes (int): Number of classes
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            inverted_residual_setting: Network structure
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
            Set to 1 to turn off rounding
        """
        super().__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280

        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # only check the first element, assuming user knows t,c,n,s are required
        if (
            len(inverted_residual_setting) == 0
            or len(inverted_residual_setting[0]) != 4
        ):
            raise ValueError(
                "inverted_residual_setting should be non-empty "
                "or a 4-element list, got {}".format(inverted_residual_setting)
            )

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(
            last_channel * max(1.0, width_mult), round_nearest
        )
        features = [ConvBNReLU(3, input_channel, stride=2)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(
                    block(input_channel, output_channel, stride, expand_ratio=t)
                )
                input_channel = output_channel
        # building last several layers
        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
        # make it nn.Sequential
        self.features = nn.Sequential(*features)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, num_classes),
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.quant(x)
        x = self.features(x)
        x = x.mean([2, 3])
        x = self.classifier(x)
        x = self.dequant(x)
        return x

    # Fuse Conv+BN and Conv+BN+Relu modules prior to quantization
    # This operation does not change the numerics
    def fuse_model(self):
        for m in self.modules():
            if type(m) == ConvBNReLU:
                torch.ao.quantization.fuse_modules(m, ["0", "1", "2"], inplace=True)
            if type(m) == InvertedResidual:
                for idx in range(len(m.conv)):
                    if type(m.conv[idx]) == nn.Conv2d:
                        torch.ao.quantization.fuse_modules(
                            m.conv, [str(idx), str(idx + 1)], inplace=True
                        )

In [10]:
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def evaluate(model, criterion, data_loader, neval_batches, device=torch.device("cpu")):
    model.eval()
    model.to(device)
    top1 = AverageMeter("Acc@1", ":6.2f")
    top5 = AverageMeter("Acc@5", ":6.2f")
    cnt = 0
    with torch.no_grad():
        for image, target in data_loader:
            image, target = image.to(device), target.to(device)
            output = model(image)
            loss = criterion(output, target)
            cnt += 1
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            print(".", end="")
            top1.update(acc1[0], image.size(0))
            top5.update(acc5[0], image.size(0))
            if cnt >= neval_batches:
                return top1, top5

    return top1, top5


def load_model(model_file):
    model = MobileNetV2()
    state_dict = torch.load(model_file)
    model.load_state_dict(state_dict)
    model.to("cpu")
    return model


def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print(f'Size (MB): {os.path.getsize("temp.p") / 1e6:.2f}')
    os.remove("temp.p")

In [11]:
def prepare_data_loaders():
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    dataset = torchvision.datasets.CIFAR10(
        root="./data",
        download=True,
        train=True,
        transform=transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )
    dataset_test = torchvision.datasets.CIFAR10(
        root="./data",
        download=True,
        train=False,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )

    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=train_batch_size, sampler=train_sampler, num_workers=16
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=eval_batch_size, sampler=test_sampler
    )

    return data_loader, data_loader_test

In [12]:
# !wget https://download.pytorch.org/models/mobilenet_v2-b0353104.pth

In [13]:
saved_model_dir = "./"
float_model_file = "mobilenet_v2-b0353104.pth"
scripted_float_model_file = "mobilenet_quantization_scripted.pth"
scripted_quantized_model_file = "mobilenet_quantization_scripted_quantized.pth"

train_batch_size = 512
eval_batch_size = 64

In [14]:
data_loader, data_loader_test = prepare_data_loaders()
criterion = nn.CrossEntropyLoss()
float_model = load_model(saved_model_dir + float_model_file).to("cpu")

# Next, we'll "fuse modules"; this can both make the model faster by saving on memory access
# while also improving numerical accuracy. While this can be used with any model, this is
# especially common with quantized models.

print("\n Inverted Residual Block: Before fusion \n\n", float_model.features[1].conv)
float_model.eval()

# Fuses modules
float_model.fuse_model()

# Note the fusion of Conv+BN+ReLU and Conv+BN
print("\n Inverted Residual Block: After fusion\n\n", float_model.features[1].conv)

Files already downloaded and verified
Files already downloaded and verified

 Inverted Residual Block: Before fusion 

 Sequential(
  (0): ConvBNReLU(
    (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

 Inverted Residual Block: After fusion

 Sequential(
  (0): ConvBNReLU(
    (0): ConvReLU2d(
      (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)
      (1): ReLU()
    )
    (1): Identity()
    (2): Identity()
  )
  (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
  (2): Identity()
)
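
In eval mode, fusing `Conv2d` with `BatchNorm2d` amounts to folding the batch-norm statistics into the convolution's weight and bias. That is why the `BatchNorm2d` modules show up as `Identity` above and the fused convolutions no longer print `bias=False`. Below is a scalar sketch of the folding (one channel, 1x1 kernel). The helper names and all numbers are hypothetical, and this is not the `fuse_modules` implementation.

```python
import math


def conv(x, w, b):
    """A 1x1, single-channel convolution is just an affine map."""
    return w * x + b


def batchnorm(y, gamma, beta, mean, var, eps=1e-5):
    """Eval-mode batch norm using the stored running statistics."""
    return gamma * (y - mean) / math.sqrt(var + eps) + beta


def fold_bn(w, b, gamma, beta, mean, var, eps=1e-5):
    """Fold batch-norm parameters into the preceding conv weight and bias."""
    k = gamma / math.sqrt(var + eps)
    return w * k, (b - mean) * k + beta


# Hypothetical parameters; bias is 0.0 because the convs above use bias=False.
w, b = 0.7, 0.0
gamma, beta, mean, var = 1.3, -0.2, 0.05, 0.4

w_fused, b_fused = fold_bn(w, b, gamma, beta, mean, var)
for x in (-1.0, 0.0, 2.5):
    reference = batchnorm(conv(x, w, b), gamma, beta, mean, var)
    fused = conv(x, w_fused, b_fused)
    print(f"x={x}: difference={abs(reference - fused):.2e}")
```

The fused convolution reproduces conv followed by batch norm up to floating-point rounding.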


In [15]:
float_model.classifier = nn.Sequential(
    nn.Dropout(p=0.2), nn.Linear(in_features=1280, out_features=10)
)
criterion = nn.CrossEntropyLoss()
opt = torch.optim.Adam(params=float_model.parameters())

num_eval_batches = 1000

In [16]:
for epoch in trange(10):
    float_model.train()
    float_model.to(device)
    for x, y in data_loader:
        x, y = x.to(device), y.to(device)
        preds = float_model(x)
        loss = criterion(preds, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
    float_model.eval()
    top1, top5 = evaluate(
        float_model,
        criterion,
        data_loader_test,
        neval_batches=num_eval_batches,
        device=device,
    )
    print(
        f"Evaluation accuracy on {(num_eval_batches * eval_batch_size)} images, {top1.avg:.2f}"
    )

  0%|          | 0/10 [00:00<?, ?it/s]

.............................................................................................................................................................Evaluation accuracy on 64000 images, 63.23
.............................................................................................................................................................Evaluation accuracy on 64000 images, 73.98
.............................................................................................................................................................Evaluation accuracy on 64000 images, 76.18
.............................................................................................................................................................Evaluation accuracy on 64000 images, 78.29
.............................................................................................................................................................Evaluation accuracy on 64000 images, 80.00


In [17]:
float_model.eval()
float_model.cpu()

print("Size of baseline model")
print_size_of_model(float_model)

top1, top5 = evaluate(
    float_model, criterion, data_loader_test, neval_batches=num_eval_batches
)
print(
    f"Evaluation accuracy on {(num_eval_batches * eval_batch_size, )} images, {top1.avg:.2f}"
)
torch.jit.save(
    torch.jit.script(float_model), saved_model_dir + scripted_float_model_file
)

Size of baseline model
Size (MB): 8.92
.............................................................................................................................................................Evaluation accuracy on (64000,) images, 81.29


Let's quantize the model!

Post-training static quantization not only converts the weights from float to int, as dynamic quantization does, but also first feeds batches of data through the network and computes the resulting distributions of the different activations (specifically, observer modules inserted at different points record this data). These distributions are then used to determine how specifically the different activations should be quantized at inference time (a simple technique is to divide the entire range of activations into 256 levels, but PyTorch supports more sophisticated methods as well). Importantly, this additional step allows quantized values to be passed between operations instead of being converted to floats, and then back to ints, between every operation, which results in a significant speed-up.
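
The "divide the range into 256 levels" idea can be sketched in plain Python. This is a minimal illustration of min/max calibration with an affine (scale, zero-point) mapping. The function names and calibration values are hypothetical, and PyTorch's observers handle more cases than this sketch does.

```python
def minmax_qparams(values, qmin=0, qmax=255):
    """Derive (scale, zero_point) from the observed min/max of `values`."""
    lo = min(min(values), 0.0)  # keep 0.0 inside the range so it maps to an integer
    hi = max(max(values), 0.0)
    scale = (hi - lo) / (qmax - qmin)
    if scale == 0.0:            # all-zero input: any positive scale works
        scale = 1.0
    zero_point = int(round(qmin - lo / scale))
    return scale, max(qmin, min(qmax, zero_point))


def quantize(x, scale, zero_point, qmin=0, qmax=255):
    """Map a float to an integer level, clamping to [qmin, qmax]."""
    return max(qmin, min(qmax, round(x / scale) + zero_point))


def dequantize(q, scale, zero_point):
    """Map an integer level back to an approximate float."""
    return (q - zero_point) * scale


# "Calibration": activations observed while feeding data through the network.
calibration = [-1.0, -0.25, 0.0, 0.5, 3.0]
scale, zero_point = minmax_qparams(calibration)
print(zero_point)  # 64

for x in calibration:
    q = quantize(x, scale, zero_point)
    print(x, q, dequantize(q, scale, zero_point))
```

Within the calibrated range, the round-trip error stays below half a quantization step (`scale / 2`).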

In [18]:
num_calibration_batches = 512

q_model = copy.deepcopy(float_model)
# Specify quantization configuration
# Start with simple min/max range estimation and per-tensor quantization of weights
q_model.qconfig = torch.ao.quantization.default_qconfig
print(q_model.qconfig)
torch.ao.quantization.prepare(q_model, inplace=True)

# Calibrate first
print("Post Training Quantization Prepare: Inserting Observers")
print(
    "\n Inverted Residual Block:After observer insertion \n\n", q_model.features[1].conv
)

# Calibrate with the training set
evaluate(q_model, criterion, data_loader, neval_batches=num_calibration_batches)
print("Post Training Quantization: Calibration done")

# Convert to quantized model
torch.ao.quantization.convert(q_model, inplace=True)
print("Post Training Quantization: Convert done")
print(
    "\n Inverted Residual Block: After fusion and quantization, note fused modules: \n\n",
    q_model.features[1].conv,
)

print("Size of model after quantization")
print_size_of_model(q_model)

top1, top5 = evaluate(
    q_model, criterion, data_loader_test, neval_batches=num_eval_batches
)
print(
    f"Evaluation accuracy on {(num_eval_batches * eval_batch_size, )} images, {top1.avg:.2f}"
)

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, quant_min=0, quant_max=127){}, weight=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){})
Post Training Quantization Prepare: Inserting Observers

 Inverted Residual Block:After observer insertion 

 Sequential(
  (0): ConvBNReLU(
    (0): ConvReLU2d(
      (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)
      (1): ReLU()
      (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
    )
    (1): Identity()
    (2): Identity()
  )
  (1): Conv2d(
    32, 16, kernel_size=(1, 1), stride=(1, 1)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (2): Identity()
)
..................................................................................................Post Training Quantization: Calibration done
Post Training Quantization: Convert done


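For this quantized model, we see lower accuracy on the eval dataset, because a simple min/max observer was used to determine the quantization parameters. Nevertheless, the model gets much smaller, since int8 weights take a quarter of the space of fp32 weights. The source tutorial reports a reduction to just under 3.6 MB, almost a 4x decrease, for its ImageNet model; the baseline here is 8.92 MB, so the size printed by `print_size_of_model(q_model)` is the number to compare against.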

In addition, we can significantly improve on the accuracy simply by using a different quantization configuration. We repeat the same exercise with the recommended configuration for quantizing for x86 architectures. This configuration does the following:

- Quantizes weights on a per-channel basis.
- Uses a histogram observer that collects a histogram of activations and then picks quantization parameters in an optimal manner.
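
To see why per-channel weight quantization can help, consider a weight matrix whose output channels differ in magnitude by a factor of 100. The sketch below uses symmetric int8 quantization. The weights and helper names are hypothetical, and the histogram observer is not illustrated here.

```python
def symmetric_scale(values, qmax=127):
    """Scale for symmetric quantization: the largest magnitude maps to qmax."""
    return max(abs(v) for v in values) / qmax


def fake_quant(values, scale, qmax=127):
    """Quantize to integer levels and map straight back to floats."""
    return [max(-qmax, min(qmax, round(v / scale))) * scale for v in values]


def max_abs_error(a, b):
    return max(abs(x - y) for x, y in zip(a, b))


# Two output channels (rows) with very different magnitudes.
weights = [
    [0.010, -0.020, 0.015],
    [1.0, -2.0, 1.5],
]

# Per-tensor: one scale shared by all channels.
shared = symmetric_scale([w for row in weights for w in row])
per_tensor_err = [max_abs_error(row, fake_quant(row, shared)) for row in weights]

# Per-channel: each row gets its own scale.
per_channel_err = [
    max_abs_error(row, fake_quant(row, symmetric_scale(row))) for row in weights
]

print("per-tensor :", per_tensor_err)
print("per-channel:", per_channel_err)
```

The small-magnitude channel is the one that benefits: with a shared scale its weights collapse onto one or two integer levels, while a per-channel scale spreads them over the full int8 range.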

In [19]:
per_channel_quantized_model = copy.deepcopy(float_model)
per_channel_quantized_model.qconfig = torch.ao.quantization.get_default_qconfig(
    "fbgemm"
)
print(per_channel_quantized_model.qconfig)

torch.ao.quantization.prepare(per_channel_quantized_model, inplace=True)
evaluate(per_channel_quantized_model, criterion, data_loader, num_calibration_batches)
torch.ao.quantization.convert(per_channel_quantized_model, inplace=True)
top1, top5 = evaluate(
    per_channel_quantized_model,
    criterion,
    data_loader_test,
    neval_batches=num_eval_batches,
)
print(
    f"Evaluation accuracy on {(num_eval_batches * eval_batch_size, )} images, {top1.avg:.2f}"
)
torch.jit.save(
    torch.jit.script(per_channel_quantized_model),
    saved_model_dir + scripted_quantized_model_file,
)

print(
    f"Evaluation accuracy on {(num_eval_batches * eval_batch_size, )} images, {top1.avg:.2f}"
)

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=True){}, weight=functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){})
...............................................................................................................................................................................................................................................................Evaluation accuracy on (64000,) images, 80.19
Evaluation accuracy on (64000,) images, 80.19


### QAT for MobileNetV2
[Source](https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html) for the section

In [20]:
def train_one_epoch(
    model, criterion, optimizer, data_loader, device, ntrain_batches_log=200
):
    model.to(device)
    model.train()
    top1 = AverageMeter("Acc@1", ":6.2f")
    top5 = AverageMeter("Acc@5", ":6.2f")
    avgloss = AverageMeter("Loss", ":1.5f")

    cnt = 0
    for image, target in data_loader:
        print(".", end="")
        cnt += 1
        image, target = image.to(device), target.to(device)
        output = model(image)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        top1.update(acc1[0], image.size(0))
        top5.update(acc5[0], image.size(0))
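        # .item() stores a plain float instead of keeping the autograd graph alive
        avgloss.update(loss.item(), image.size(0))
        if cnt % ntrain_batches_log == 0:
            # periodic progress report every `ntrain_batches_log` batches
            print("Loss", avgloss.avg)
            print(f"Training: * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}")

    print(f"Full train set:  * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}")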

In [21]:
qat_model = copy.deepcopy(float_model)
qat_model.train()
optimizer = torch.optim.SGD(qat_model.parameters(), lr=0.0001)
qat_model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")

In [22]:
# prepare_qat inserts fake-quantization modules, preparing the model for quantization-aware training
torch.ao.quantization.prepare_qat(qat_model, inplace=True)
print(
    "Inverted Residual Block: After preparation for QAT, note fake-quantization modules \n",
    qat_model.features[1].conv,
)

Inverted Residual Block: After preparation for QAT, note fake-quantization modules 
 Sequential(
  (0): ConvBNReLU(
    (0): ConvReLU2d(
      32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32
      (weight_fake_quant): FusedMovingAvgObsFakeQuantize(
        fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.qint8, quant_min=-128, quant_max=127, qscheme=torch.per_channel_symmetric, reduce_range=False
        (activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
      )
      (activation_post_process): FusedMovingAvgObsFakeQuantize(
        fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True
        (activation_post_process): MovingAverageMinMaxObserver(min_val=inf
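
A fake-quantization module quantizes and immediately dequantizes its input in the forward pass, so the network trains against the rounding error while all tensors stay in floating point. The sketch below is a scalar illustration with a straight-through gradient. The function names and the `scale` value are hypothetical, and it is not PyTorch's `FusedMovingAvgObsFakeQuantize` implementation.

```python
def fake_quantize(x, scale, zero_point, qmin=0, qmax=127):
    """Forward pass: round to an integer level, clamp, and map back to float."""
    q = round(x / scale) + zero_point
    q = max(qmin, min(qmax, q))
    return (q - zero_point) * scale


def straight_through_grad(x, scale, zero_point, qmin=0, qmax=127):
    """Backward pass sketch: treat rounding as identity inside the representable range."""
    lo = (qmin - zero_point) * scale
    hi = (qmax - zero_point) * scale
    return 1.0 if lo <= x <= hi else 0.0


scale, zero_point = 0.1, 0  # hypothetical quantization parameters
for x in (0.234, 5.0, 20.0):
    y = fake_quantize(x, scale, zero_point)
    g = straight_through_grad(x, scale, zero_point)
    print(f"x={x}: fake-quantized={y:.3f}, gradient={g}")
```

Values outside the representable range are clamped and receive zero gradient in this sketch; the `quant_min=0, quant_max=127` defaults mirror the activation settings printed above.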

In [23]:
# QAT takes time and one needs to train over a few epochs.
# Train and check accuracy after each epoch
for nepoch in range(10):
    train_one_epoch(qat_model, criterion, optimizer, data_loader, device=device)
    if nepoch > 3:
        # Freeze quantizer parameters
        qat_model.apply(torch.ao.quantization.disable_observer)
    if nepoch > 2:
        # Freeze batch norm mean and variance estimates
        qat_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

    # Check the accuracy after each epoch
    quantized_model = torch.ao.quantization.convert(
        qat_model.cpu().eval(), inplace=False
    )
    top1, top5 = evaluate(
        quantized_model, criterion, data_loader_test, neval_batches=num_eval_batches
    )
    print(
        f"Evaluation accuracy on {(num_eval_batches * eval_batch_size, )} images, {top1.avg:.2f}"
    )

..................................................................................................Full train set:  * Acc@1 88.008 Acc@5 99.502
.............................................................................................................................................................Evaluation accuracy on (64000,) images, 82.69
..................................................................................................Full train set:  * Acc@1 88.984 Acc@5 99.598
.............................................................................................................................................................Evaluation accuracy on (64000,) images, 82.98
..................................................................................................Full train set:  * Acc@1 89.328 Acc@5 99.622
.............................................................................................................................................................Evaluati

### Speed benchmark
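Does quantization actually speed up inference? Yes: on this CPU the per-channel quantized model takes about 0.31 ms per image versus 0.39 ms for the float model.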

In [24]:
elapsed = 0
num_images = 0
model = per_channel_quantized_model
model.eval()
num_batches = 100
# Time the forward pass on the first `num_batches` test batches
with torch.no_grad():
    for i, (images, target) in enumerate(data_loader_test):
        if i >= num_batches:
            break
        start = time()
        output = model(images)
        end = time()
        elapsed = elapsed + (end - start)
        num_images += images.size(0)

# Average time per image, in milliseconds
print(f"Elapsed time: {(elapsed / num_images * 1000)} ms")

Elapsed time: 0.30716948211193085 ms


In [25]:
elapsed = 0
num_images = 0
model = float_model
model.eval()
num_batches = 100
# Time the forward pass on the first `num_batches` test batches
with torch.no_grad():
    for i, (images, target) in enumerate(data_loader_test):
        if i >= num_batches:
            break
        start = time()
        output = model(images)
        end = time()
        elapsed = elapsed + (end - start)
        num_images += images.size(0)

# Average time per image, in milliseconds
print(f"Elapsed time: {(elapsed / num_images * 1000)} ms")

Elapsed time: 0.3881186619400978 ms
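
Single-pass timings like the two above can be noisy. A slightly more careful measurement adds warm-up runs and reports the median over several repetitions. The `benchmark` helper below is a hypothetical sketch, not part of the source tutorial, and it times a stand-in workload so that it runs without a model.

```python
import statistics
from time import perf_counter


def benchmark(fn, n_warmup=3, n_runs=20):
    """Return (median, stdev) of the wall-clock time of `fn()` in seconds."""
    for _ in range(n_warmup):  # warm-up runs are not timed
        fn()
    timings = []
    for _ in range(n_runs):
        start = perf_counter()
        fn()
        timings.append(perf_counter() - start)
    return statistics.median(timings), statistics.stdev(timings)


# Stand-in workload; in the notebook one could pass `lambda: model(images)`
# for a fixed batch `images` instead.
median_s, std_s = benchmark(lambda: sum(i * i for i in range(10_000)))
print(f"{median_s * 1e3:.3f} ms (stdev {std_s * 1e3:.3f} ms)")
```

`perf_counter` is used because it is a monotonic, high-resolution clock intended for measuring short durations.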


## 42 GB T5 to a single GPU
[Source](https://huggingface.co/blog/hf-bitsandbytes-integration) of the section

In [26]:
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

In [27]:
model_name = "t5-3b-sharded"  # @param ["t5-11b-sharded", "t5-3b-sharded"]

# T5-3b and T5-11B are supported!
# We need sharded weights otherwise we get CPU OOM errors
model_id = f"ybelkada/{model_name}"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model_8bit = AutoModelForSeq2SeqLM.from_pretrained(
    model_id, device_map="auto", load_in_8bit=True
)


Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so
CUDA SETUP: Highest compute capability among GPUs detected: 7.0
CUDA SETUP: Detected CUDA version 111
CUDA SETUP: Loading binary /home/ubuntu/anaconda3/envs/ml/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda111_nocublaslt.so...


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

In [28]:
model_8bit.get_memory_footprint() / 1e9

5.300543488

For t5-3b, the int8 model takes about 5.3 GB, whereas the original model takes 11 GB. For t5-11b, the int8 model takes about 11 GB versus 42 GB for the original model. Now let's generate some text and look at the qualitative results of the 8-bit model!
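
The t5-11b figures match a rough estimate of parameter count times bytes per parameter. The sketch below assumes roughly 11 billion parameters (inferred from the model name) and 4 bytes per parameter for the original checkpoint; both are assumptions, and measured footprints such as the 5.3 GB above need not match this estimate exactly.

```python
def weights_gb(n_params, bytes_per_param):
    """Approximate size of the weights in GB (1 GB = 1e9 bytes)."""
    return n_params * bytes_per_param / 1e9


n_params_t5_11b = 11e9  # assumption: ~11 billion parameters

print(weights_gb(n_params_t5_11b, 4))  # 44.0 -> same ballpark as the quoted 42 GB
print(weights_gb(n_params_t5_11b, 1))  # 11.0 -> matches the quoted ~11 GB for int8
```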

In [29]:
max_new_tokens = 50

input_ids = tokenizer(
    "translate English to German: Hello my name is Younes and I am a Machine Learning Engineer at Hugging Face",
    return_tensors="pt",
).input_ids

outputs = model_8bit.generate(input_ids, max_new_tokens=max_new_tokens)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Hallo mein Name ist Younes und ich bin ein Ingenieur für Machine Learning bei Hugging Face
