<a href="https://colab.research.google.com/github/hideaki-kyutech/softcomp2024/blob/main/quantization2024.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 量子化: FP32 to BF16

## FP32でのFashionMNISTの分類モデルの作成

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

# Fix the global RNG seed so that weight initialisation and the shuffling order of
# the training loader are repeatable across "Restart Kernel -> Run All" executions.
SEED = 42
torch.manual_seed(SEED)

# Tunable settings of the data pipeline, kept in one place
DATA_DIR = './data'
BATCH_SIZE = 4

# Convert images to tensors, then normalise with mean 0.5 / std 0.5 so that
# pixel values end up in the range -1 to 1
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

# Create datasets for training & validation, download if not available
training_set = torchvision.datasets.FashionMNIST(DATA_DIR, train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST(DATA_DIR, train=False, transform=transform, download=True)

# Create data loaders for our datasets
# Shuffle the training data to ensure randomness; validation data does not need shuffling
training_loader = torch.utils.data.DataLoader(training_set, batch_size=BATCH_SIZE, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=BATCH_SIZE, shuffle=False)

# Human-readable class labels, indexed by the integer target of the dataset
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

# Define the model architecture for garment classification
class GarmentClassifier(nn.Module):
    """LeNet-style CNN that maps single-channel images to 10 class logits.

    The flatten size of 16 * 4 * 4 corresponds to 28x28 inputs:
    28 -> 24 (conv 5x5) -> 12 (pool) -> 8 (conv 5x5) -> 4 (pool).
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: two 5x5 convolutions (1 -> 6 -> 16 channels); the same
        # 2x2 / stride-2 max-pooling module is applied after each of them in forward()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Classifier head: 256 -> 120 -> 84 -> 10 (one logit per class)
        self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the raw (un-normalised) class logits for the batch ``x``."""
        # conv -> ReLU -> pool, twice
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        # One row of 256 values per sample for the fully connected layers
        flat = features.view(-1, 16 * 4 * 4)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        # No activation on the last layer: the outputs are logits
        return self.fc3(hidden)

# Build the network (weights are randomly initialised at construction time)
model = GarmentClassifier()

# Multi-class classification objective: cross-entropy on the raw logits
loss_fn = nn.CrossEntropyLoss()

# Stochastic gradient descent with momentum over all model parameters
optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.9)

# Function to train the model for one epoch
def train_one_epoch(epoch_index, tb_writer):
    """Run one pass over ``training_loader`` and update ``model`` in place.

    Uses the notebook-level ``model``, ``optimizer``, ``loss_fn`` and
    ``training_loader``.

    Args:
        epoch_index: Zero-based epoch number; only used to compute the global
            step under which the loss is logged.
        tb_writer: Object providing ``add_scalar(tag, value, step)``.

    Returns:
        The mean loss of the most recently completed window of 1000 batches,
        or 0.0 if fewer than 1000 batches were processed.
    """
    window_loss = 0.0  # sum of batch losses since the last report
    last_loss = 0.0

    # Count batches from 1 so that `step` is directly the number of batches seen
    for step, (inputs, labels) in enumerate(training_loader, start=1):
        # Gradients accumulate by default, so clear them before each batch
        optimizer.zero_grad()

        # Forward pass, loss, backward pass, weight update
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        optimizer.step()

        window_loss += loss.item()
        if step % 1000 != 0:
            continue

        # A full window of 1000 batches is complete: report its mean loss
        last_loss = window_loss / 1000
        print('  batch {} loss: {}'.format(step, last_loss))
        global_step = epoch_index * len(training_loader) + step
        tb_writer.add_scalar('Loss/train', last_loss, global_step)
        window_loss = 0.0

    return last_loss

# Set up TensorBoard to visualize training metrics; the timestamp makes the run
# directory and the checkpoint file names unique per execution
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

# Number of epochs to train the model
EPOCHS = 5

# Initialize the best validation loss with a high value
best_vloss = 1_000_000.0

# Training loop
for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Training mode enables the train-time behaviour of layers such as dropout and
    # batch norm; it does not by itself switch gradient tracking on or off
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)

    running_vloss = 0.0

    # Set the model to evaluation mode (disable dropout and batchnorm updates)
    model.eval()

    # Disable gradient computation for validation (saves memory and computation)
    with torch.no_grad():
        for vinputs, vlabels in validation_loader:
            # Forward pass for validation data
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss.item()

    # Mean of the per-batch validation losses. The batch count comes from the
    # loader itself instead of a loop index that leaks out of the loop above.
    avg_vloss = running_vloss / len(validation_loader)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log training vs. validation loss to TensorBoard
    writer.add_scalars('Training vs. Validation Loss',
                       {'Training': avg_loss, 'Validation': avg_vloss},
                       epoch_number + 1)
    writer.flush()

    # Save the model if the current validation loss is the best so far
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    # Increment the epoch counter
    epoch_number += 1

# Training is finished: write out any pending events and close the event file
writer.close()


Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26.4M/26.4M [00:01<00:00, 16.5MB/s]


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29.5k/29.5k [00:00<00:00, 271kB/s]


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4.42M/4.42M [00:00<00:00, 5.03MB/s]


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5.15k/5.15k [00:00<00:00, 4.65MB/s]


Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

EPOCH 1:
  batch 1000 loss: 1.7966186945587397
  batch 2000 loss: 0.8062484353855253
  batch 3000 loss: 0.6907540008034557
  batch 4000 loss: 0.6144923003916629
  batch 5000 loss: 0.5742344011918176
  batch 6000 loss: 0.5155273887147196
  batch 7000 loss: 0.5130124745445792
  batch 8000 loss: 0.48083408589911414
  batch 9000 loss: 0.4910981658545788
  batch 10000 loss: 0.4774242734978907
  batch 11000 loss: 0.43967159994679966
  batch 12000 loss: 0.42548703986563485
  batch 13000 loss: 0.4448238311446039
  batch 14000 loss: 0.3879323586110259
  batch 15000 loss: 0.40608691282640214
LOSS train 0.40608691282640214 valid 0.39423666497434023
EPOCH 2:
  batch 1000 loss: 0.41024087067562504
  batch 2000 loss: 0.35966621555329764
  batch 3000 loss: 0.39620297122129705
  batch 4000 loss: 0.3746347420369857
  batch 5000 loss: 0.3604794119874132
  batch 6000 loss: 0.3574897563600389
  batch 7000 loss: 0.3472

In [None]:
def print_param_dtype(model):
    """Print one line per named parameter of ``model`` stating its dtype."""
    named_params = model.named_parameters()
    for entry in named_params:
        # Each entry is a (qualified parameter name, parameter tensor) pair
        name, param = entry
        print(f"{name} is loaded in {param.dtype}")

In [None]:
# Sanity check before any conversion: list the dtype of every parameter of the trained model
print_param_dtype(model)

conv1.weight is loaded in torch.float32
conv1.bias is loaded in torch.float32
conv2.weight is loaded in torch.float32
conv2.bias is loaded in torch.float32
fc1.weight is loaded in torch.float32
fc1.bias is loaded in torch.float32
fc2.weight is loaded in torch.float32
fc2.bias is loaded in torch.float32
fc3.weight is loaded in torch.float32
fc3.bias is loaded in torch.float32


## 学習済みモデルのBF16への変換

In [None]:
from copy import deepcopy

In [None]:
# Work on an independent copy so that the FP32 `model` stays available for the comparison below
model_bf16 = deepcopy(model)

In [None]:
# Cast the parameters of the copy to bfloat16 (verified by the dtype listing in the next cell)
model_bf16 = model_bf16.to(torch.bfloat16)

In [None]:
# Confirm the cast: list the dtype of every parameter of the converted copy
print_param_dtype(model_bf16)

conv1.weight is loaded in torch.bfloat16
conv1.bias is loaded in torch.bfloat16
conv2.weight is loaded in torch.bfloat16
conv2.bias is loaded in torch.bfloat16
fc1.weight is loaded in torch.bfloat16
fc1.bias is loaded in torch.bfloat16
fc2.weight is loaded in torch.bfloat16
fc2.bias is loaded in torch.bfloat16
fc3.weight is loaded in torch.bfloat16
fc3.bias is loaded in torch.bfloat16


In [None]:
def evaluate_model(model, test_loader):
    """Return the top-1 accuracy (in percent) of ``model`` on ``test_loader``.

    Floating-point input batches are converted to the dtype and device of the
    model's first parameter before the forward pass, so the same function can
    evaluate an FP32 model as well as a reduced-precision (e.g. bfloat16) copy.
    For an FP32 model fed FP32 inputs the conversion is a no-op.

    Args:
        model: ``torch.nn.Module`` producing per-class scores with the class
            dimension at index 1.
        test_loader: Iterable yielding ``(images, labels)`` tensor batches.

    Returns:
        Accuracy as a float between 0 and 100.

    Raises:
        ZeroDivisionError: If ``test_loader`` yields no samples.
    """
    # Set the model to evaluation mode
    model.eval()

    # The first parameter (if any) tells us which dtype/device the model computes in
    reference = next(model.parameters(), None)
    cast_inputs = reference is not None and reference.is_floating_point()

    # Initialize metrics
    correct = 0
    total = 0

    # Disable gradient calculation for inference
    with torch.no_grad():
        for images, labels in test_loader:
            if reference is not None:
                if cast_inputs and images.is_floating_point():
                    # Match the model's precision; returns a new tensor, the
                    # batch held by the loader is left untouched
                    images = images.to(device=reference.device, dtype=reference.dtype)
                else:
                    images = images.to(reference.device)
                # Labels must live on the same device as the predictions
                labels = labels.to(reference.device)

            # Forward pass: get predictions
            outputs = model(images)
            # Index of the largest score per sample is the predicted class
            _, predicted = torch.max(outputs, 1)

            # Update metrics
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Calculate and return accuracy
    accuracy = 100 * correct / total
    return accuracy

In [None]:
# Load the test dataset
# NOTE: train=False selects the same FashionMNIST split that `validation_set` was built from,
# so the accuracies reported below are measured on the data already used for validation.
test_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=4, shuffle=False)

In [None]:
def evaluate_model_bf16(model, test_loader):
    """Return the top-1 accuracy (in percent) of a bfloat16 ``model``.

    Every input batch is converted to ``torch.bfloat16`` before the forward
    pass so that it matches the dtype of the model's weights.

    Args:
        model: Module whose parameters are in bfloat16.
        test_loader: Iterable yielding ``(images, labels)`` tensor batches.

    Returns:
        Accuracy as a float between 0 and 100.
    """
    # Inference only: evaluation mode, and no gradient bookkeeping below
    model.eval()

    n_hits = 0
    n_seen = 0

    with torch.no_grad():
        for images, labels in test_loader:
            # Cast the inputs to the model's precision, then run the forward pass
            logits = model(images.to(torch.bfloat16))
            # The index of the largest score per sample is the predicted class
            predicted = torch.max(logits.data, 1).indices
            n_seen += labels.size(0)
            n_hits += (predicted == labels).sum().item()

    return 100 * n_hits / n_seen

## FP32モデルでの性能

In [None]:
# Baseline: accuracy of the original FP32 model on the test loader
test_accuracy = evaluate_model(model, test_loader)
print(f'Accuracy of the model on the test set 32 bits: {test_accuracy:.2f}%')

Accuracy of the model on the test set 32 bits: 88.96%


## BF16モデルでの性能

In [None]:
# Same test loader, evaluated with the BF16 copy (inputs are cast to BF16 inside evaluate_model_bf16).
# NOTE: this reuses the name `test_accuracy`, overwriting the FP32 result of the previous cell.
test_accuracy = evaluate_model_bf16(model_bf16, test_loader)
print(f'Accuracy of the model on the test set BF16: {test_accuracy:.2f}%')

Accuracy of the model on the test set BF16: 88.96%


In [None]:
# Save only the weights (state_dict) of each model to its own file so that
# the on-disk sizes of the FP32 and BF16 versions can be compared below
torch.save(model.state_dict(), "garment_class_FP32.pth")
torch.save(model_bf16.state_dict(), "garment_class_bf16.pth")

## FP32モデルとBF16モデルのモデルサイズ確認(garment_class_FP32.pth と garment_class_bf16.pth が比較対象のモデルファイル)

In [None]:
ls -al

total 1196
drwxr-xr-x 1 root root   4096 Feb  5 22:37 [0m[01;34m.[0m/
drwxr-xr-x 1 root root   4096 Feb  5 22:26 [01;34m..[0m/
drwxr-xr-x 4 root root   4096 Feb  4 14:22 [01;34m.config[0m/
drwxr-xr-x 3 root root   4096 Feb  5 22:28 [01;34mdata[0m/
-rw-r--r-- 1 root root  92498 Feb  5 22:37 garment_class_bf16.pth
-rw-r--r-- 1 root root 181394 Feb  5 22:37 garment_class_FP32.pth
-rw-r--r-- 1 root root 181528 Feb  5 22:29 model_20250205_222821_0
-rw-r--r-- 1 root root 181528 Feb  5 22:30 model_20250205_222821_1
-rw-r--r-- 1 root root 181528 Feb  5 22:32 model_20250205_222821_2
-rw-r--r-- 1 root root 181528 Feb  5 22:33 model_20250205_222821_3
-rw-r--r-- 1 root root 181528 Feb  5 22:34 model_20250205_222821_4
drwxr-xr-x 3 root root   4096 Feb  5 22:28 [01;34mruns[0m/
drwxr-xr-x 1 root root   4096 Feb  4 14:22 [01;34msample_data[0m/
