# L5-C: Unpacking 2-Bit Weights

In this lesson, you will learn how to unpack low-precision weights that have been stored in packed form.

In [1]:
import torch

## Unpacking

**Note:** Younes will explain the code below and walk through each iteration step. After watching Younes's explanation, you can go through the comprehensive explanation in the markdown below.

```Python
# Example packed Tensor: [10110001]
# It packs the 2-bit values 1 0 3 2 (01 00 11 10),
# with the first value stored in the lowest two bits

# Starting point of unpacked Tensor
# [00000000 00000000 00000000 00000000]

##### First Iteration Start:
# packed int8 Tensor: [10110001]
# You want to extract 01 from [101100 01]
# No right shifts in the First Iteration
# After bit-wise OR operation between 00000000 and 10110001:
# [10110001 00000000 00000000 00000000]
# unpacked Tensor state: [10110001 00000000 00000000 00000000]
##### First Iteration End

##### Second Iteration Start:
# packed int8 Tensor: [10110001]
# You want to extract 00 from [1011 00 01]
# 2 right shifts:
# [10110001] (1 shift)-> 01011000 (2 shifts)-> 00101100
# After bit-wise OR operation between 00000000 and 00101100:
# [10110001 00101100 00000000 00000000]
# unpacked Tensor state: [10110001 00101100 00000000 00000000]
##### Second Iteration End

##### Third Iteration Start:
# packed int8 Tensor: [10110001]
# You want to extract 11 from [10 11 0001]
# 4 right shifts:
# [10110001] (1 shift)-> 01011000 (2 shifts)-> 00101100
# 00101100 (3 shifts)-> 00010110 (4 shifts)-> 00001011
# After bit-wise OR operation between 00000000 and 00001011:
# [10110001 00101100 00001011 00000000]
# unpacked Tensor state: [10110001 00101100 00001011 00000000]
##### Third Iteration End

##### Fourth Iteration Start:
# packed int8 Tensor: [10110001]
# You want to extract 10 from [10 110001]
# 6 right shifts:
# [10110001] (1 shift)-> 01011000 (2 shifts)-> 00101100
# 00101100 (3 shifts)-> 00010110 (4 shifts)-> 00001011
# 00001011 (5 shifts)-> 00000101 (6 shifts)-> 00000010
# After bit-wise OR operation between 00000000 and 00000010:
# [10110001 00101100 00001011 00000010]
# unpacked Tensor state: [10110001 00101100 00001011 00000010]
##### Fourth Iteration End

# Last step: Perform masking (bit-wise AND operation)
# Mask: 00000011
# Bit-wise AND operation between
# unpacked Tensor and 00000011
# [10110001 00101100 00001011 00000010] <- unpacked Tensor
# [00000011 00000011 00000011 00000011] <- Mask
# [00000001 00000000 00000011 00000010] <- Result

# Final
# unpacked Tensor state: [00000001 00000000 00000011 00000010]

```
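The walkthrough above can be reproduced for a single byte with plain Python integers. This is only a minimal sketch: `unpack_byte` is a hypothetical helper, not part of the lesson code, and it applies the mask inside each step rather than once at the end, which should give the same values.

```python
def unpack_byte(byte, bits=2):
    """Split one packed byte into its `bits`-wide values, lowest bits first."""
    mask = (1 << bits) - 1          # 0b11 when bits == 2
    values = []
    for step in range(8 // bits):
        shifted = byte >> (bits * step)   # 0, 2, 4, 6 right shifts for bits == 2
        values.append(shifted & mask)     # keep only the lowest `bits` bits
    return values

print(unpack_byte(0b10110001))  # [1, 0, 3, 2]
```

Each step matches one iteration of the walkthrough: the shift moves the wanted pair of bits to the lowest position, and the mask discards everything above it.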

In [5]:
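def unpack_weights(uint8tensor, bits):
    # Each uint8 element holds (8 // bits) packed values
    num_values = uint8tensor.shape[0] * 8 // bits
    num_steps = 8 // bits
    unpacked_tensor = torch.zeros((num_values), dtype=torch.uint8)
    unpacked_idx = 0

    for i in range(uint8tensor.shape[0]):
        for j in range(num_steps):
            # Shift the j-th value down into the lowest `bits` bits;
            # the leftover higher bits are cleared by the mask below
            unpacked_tensor[unpacked_idx] |= uint8tensor[i] >> (bits * j)
            unpacked_idx += 1

    # Keep only the lowest `bits` bits of every element
    mask = 2 ** bits - 1
    unpacked_tensor &= mask

    return unpacked_tensor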
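
The nested loops touch one element at a time. As a possible alternative, the same shift-and-mask idea can be written with broadcasting. This is a sketch under assumptions: `unpack_weights_vectorized` is a hypothetical name, and the input is widened to `int32` before shifting purely as a conservative choice.

```python
import torch

def unpack_weights_vectorized(uint8tensor, bits):
    # Hypothetical helper, not part of the lesson code.
    # One shift amount per packed value: 0, bits, 2 * bits, ...
    shifts = torch.arange(8 // bits, dtype=torch.int32) * bits
    # Shape (N, 1) so each byte is broadcast against every shift
    widened = uint8tensor.to(torch.int32).unsqueeze(1)
    # Shift, then keep only the lowest `bits` bits of each result
    unpacked = (widened >> shifts) & (2 ** bits - 1)
    # Flatten row by row so values keep their original order
    return unpacked.reshape(-1).to(torch.uint8)

packed = torch.tensor([177, 255], dtype=torch.uint8)
print(unpack_weights_vectorized(packed, bits=2))
```

For the two-byte example used in this lesson, the output should match the loop-based `unpack_weights`.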

In [6]:
# 177 is 0b10110001, the packed byte from the walkthrough above
packed_tensor = torch.tensor([177, 255], dtype=torch.uint8)

In [7]:
# Answer should be: torch.tensor([1, 0, 3, 2, 3, 3, 3, 3])
unpack_weights(packed_tensor, bits=2)

tensor([1, 0, 3, 2, 3, 3, 3, 3], dtype=torch.uint8)
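
To check by hand why `[177, 255]` is the packed form of that answer, the inverse operation can be sketched in plain Python. `pack_values` is a hypothetical helper, not the course's packing function, and it assumes the first value goes into the lowest bits, as the walkthrough implies.

```python
def pack_values(values, bits=2):
    """Pack small integers into bytes, first value in the lowest bits."""
    per_byte = 8 // bits
    packed = []
    for start in range(0, len(values), per_byte):
        byte = 0
        for j, value in enumerate(values[start:start + per_byte]):
            byte |= value << (bits * j)   # place value j at bit offset bits * j
        packed.append(byte)
    return packed

print(pack_values([1, 0, 3, 2, 3, 3, 3, 3]))  # [177, 255]
```

Packing the expected answer gives back the two input bytes, so the unpacked result is consistent with the input.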