# DTypes & Devices: Choose Your Weapons

**Module 1 | Lesson 3**

---

### Professor Torchenstein's Grand Directive

Mwahahaha! You've sliced, fused, and reshaped tensors with the skill of a master surgeon! You can command their *form*, but what of their *soul*? What of their very *essence*?

Today, we delve deeper! We shall master the two most fundamental properties of any tensor: its **data type (`dtype`)**, which determines its precision and power, and its **device**, the very dimension it inhabits—be it the humble CPU or the roaring, incandescent GPU! Choose your weapons wisely, for these choices dictate the speed, precision, and ultimate success of your grand experiments!

![Torchenstein holding a glowing cube](/assets/images/torchenstein_presenting_cube.png)

### Your Mission Briefing

By the end of this electrifying session, you will have mastered the arcane arts of:

*   **🔬 Identifying** a tensor's data type (`dtype`) and the computational device it resides on.
*   **✨ Transmuting** tensors between different `dtypes` to balance precision and performance.
*   **⚡ Teleporting** tensors between the CPU and GPU to accelerate your computations.
*   **⚠️ Dodging** the catastrophic errors that arise from mismatched `dtypes` and devices.

**Estimated Time to Completion:** 15 minutes of alchemical mastery.

**What You'll Need:**
*   The wisdom from our previous lessons on [tensor summoning](../01_introduction_to_tensors), [surgery](../02a_tensor_manipulation), and [metamorphosis](../02b_tensor_metamorphosis).
*   A PyTorch environment, preferably with a GPU waiting to be awakened!
*   A thirst for computational power!



### Previously in the Lab... (A Quick Recap)

In our last experiment, we mastered **Tensor Metamorphosis**, transforming tensor shapes with `reshape`, `view`, `squeeze`, and `unsqueeze`. We learned that a tensor's shape is merely an illusion—a view into a contiguous block of 1D memory.

Now that you command a tensor's external form, we shall master its internal essence. The journey continues!



## Part 2: The Alchemist's Arsenal - Mastering Data Types (`dtype`)

Behold, apprentice! Not all tensors are forged from the same ethereal stuff. The very *essence* of a tensor—its `.dtype`—determines what kind of numbers it can hold, its precision in the arcane arts of mathematics, and the amount of precious memory it consumes!

A wise choice of `dtype` can mean the difference between a lightning-fast model and a sluggish, memory-guzzling behemoth. Let us inspect the primary weapons in our arsenal!



### Checking the soul of your tensor

We will summon tensors of different `dtypes`, transmute them, and witness the performance implications firsthand!

To perform these miracles, you must master two key tools:
*   The `.dtype` attribute reveals a tensor's data type. Its sibling, `.device`, reveals where the tensor lives. Both are read-only, so you cannot change them by assignment, but you can inspect them to understand your tensor's essence.
*   The `.to()` method is your transmutation spell! It returns a tensor with the requested `dtype` and/or `device` and leaves the original untouched. It can change both at the same time!
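Here is a minimal sketch of both tools on a CPU tensor. The variable names are illustrative. It shows that `.to()` hands back a new tensor and does not modify the original. It also shows that `.to()` returns the same object when nothing needs to change, which is the behavior PyTorch documents for `Tensor.to`.

```python
import torch

t = torch.zeros(2, 3)                      # default dtype (float32) on the CPU
print(t.dtype, t.device)                   # torch.float32 cpu

t_half = t.to(torch.float16)               # a NEW tensor with the requested dtype
print(t_half.dtype, t.dtype)               # torch.float16 torch.float32 (original untouched)

t_both = t.to(device="cpu", dtype=torch.bfloat16)  # dtype and device in one call
print(t_both.dtype, t_both.device)         # torch.bfloat16 cpu

# If nothing needs to change, .to() returns the very same tensor object
print(t.to(torch.float32) is t)            # True
```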



### Floating-Point Types (The Elixirs of Learning)
*The very lifeblood of neural networks! These are essential for representing real numbers, calculating gradients, and enabling your models to learn.*

-   `torch.float32` (`torch.float`): The 32-bit workhorse. This is the default `dtype` for a reason—it offers a fantastic balance between precision and performance. Most of your initial experiments will thrive on this reliable elixir.
-   `torch.float64` (`torch.double`): 64-bit, for when you require the utmost, surgical precision. Its use in deep learning is rare, as it doubles memory usage and can slow down computations, but for certain scientific calculations, it is indispensable. A powerful tool, but often overkill for our purposes!
-   `torch.float16` (`torch.half`): A 16-bit potion for speed and memory efficiency. Halving the bit width can dramatically accelerate training on modern GPUs and cut your memory footprint in half! But beware: its limited range can sometimes lead to numerical instability.
-   `torch.bfloat16`: The new favorite in the high council of AI! Also 16-bit, but with a crucial difference from `float16`. It sacrifices some precision to maintain the same dynamic range as `float32`, making it far more stable for training large models like Transformers.
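You do not need to memorize these limits, because `torch.finfo` reports them directly. The sketch below prints four values for the three formats most relevant to training:

- the bit width,
- the largest finite value,
- the smallest normal value,
- `eps`, the gap between 1.0 and the next representable number.

```python
import torch

# torch.finfo exposes the numerical limits of a floating-point dtype
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):15} bits={info.bits:2}  max={info.max:.4g}  "
          f"smallest normal={info.tiny:.4g}  eps={info.eps:.4g}")
```

The output shows the trade-off in numbers:

- `float16` tops out at 65,504, while `bfloat16` reaches about 3.4 x 10^38.
- `bfloat16`'s `eps` is 2^-7, versus 2^-10 for `float16`.

In other words, `bfloat16` buys its range with precision.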

```python
import torch

# A single number for our comparison
pi_number = 3.14159265

# Summoning tensors of different float dtypes
tensor_fp64 = torch.tensor(pi_number, dtype=torch.float64)
tensor_fp32 = torch.tensor(pi_number, dtype=torch.float32)
tensor_fp16 = torch.tensor(pi_number, dtype=torch.float16)
tensor_bf16 = torch.tensor(pi_number, dtype=torch.bfloat16)

print("--- Floating-Point Memory Footprints ---")
print(f"{tensor_fp64.dtype}: {tensor_fp64.item():.8f} | Memory: {tensor_fp64.element_size()} bytes")
print(f"{tensor_fp32.dtype}: {tensor_fp32.item():.8f} | Memory: {tensor_fp32.element_size()} bytes")
print(f"{tensor_fp16.dtype}: {tensor_fp16.item():.8f} | Memory: {tensor_fp16.element_size()} bytes (Half the size of fp32!)")
print(f"{tensor_bf16.dtype}: {tensor_bf16.item():.8f} | Memory: {tensor_bf16.element_size()} bytes (Half the size of fp32!)")

# A large number for our comparison
large_number = 70000.0
tensor_fp64 = torch.tensor(large_number, dtype=torch.float64)
tensor_fp32 = torch.tensor(large_number, dtype=torch.float32)
tensor_fp16 = torch.tensor(large_number, dtype=torch.float16)
tensor_bf16 = torch.tensor(large_number, dtype=torch.bfloat16)

print("--- Floating-Point Range: 70,000 ---")
print(f"{tensor_fp64.dtype}: {tensor_fp64.item():.8f}")
print(f"{tensor_fp32.dtype}: {tensor_fp32.item():.8f}")
print(f"{tensor_fp16.dtype}: {tensor_fp16.item():.8f}")
print(f"{tensor_bf16.dtype}: {tensor_bf16.item():.8f}")
```

```text
--- Floating-Point Memory Footprints ---
torch.float64: 3.14159265 | Memory: 8 bytes
torch.float32: 3.14159274 | Memory: 4 bytes
torch.float16: 3.14062500 | Memory: 2 bytes (Half the size of fp32!)
torch.bfloat16: 3.14062500 | Memory: 2 bytes (Half the size of fp32!)
--- Floating-Point Range: 70,000 ---
torch.float64: 70000.00000000
torch.float32: 70000.00000000
torch.float16: inf
torch.bfloat16: 70144.00000000
```



#### The Curious Case of `bfloat16` and the Number 70,144

Mwahahaha! Apprentice, you have sharp eyes! You witnessed a strange transmutation: our number `70000.0` became `70144.0` when cast to `bfloat16`. Is this a bug? A flaw in our alchemy? No! This is a profound secret about the very fabric of digital reality!

To understand this, we must journey into the heart of the machine and see how it stores floating-point numbers.

##### The Blueprint of a Float: Scientific Notation in Binary

Every floating-point number in your computer's memory is stored like a secret formula with three parts:

1.  **The Sign (S)**: A single bit (0 for positive, 1 for negative).
2.  **The Exponent (E)**: A set of bits that represent the number's magnitude or *range*, like the `10^x` part in scientific notation.
3.  **The Mantissa (M)**: A set of bits that represent the actual digits of the number—its *precision*.

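The number is reconstructed as $(-1)^S \times 1.M \times 2^{E - \text{bias}}$. The bias is 127 for `float32` and `bfloat16` and 15 for `float16`; it lets the stored exponent field represent negative powers too.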

The mantissa is the key here. For normal numbers, it is a binary fraction that always starts with an implicit `1.` (this leading 1 is not stored), followed by a sum of fractional powers of 2. `float32` has 23 mantissa bits $m_1, \dots, m_{23}$, each 0 or 1, so its mantissa is:

$$ 1.M = 1 + \frac{m_1}{2} + \frac{m_2}{4} + \frac{m_3}{8} + \dots + \frac{m_{23}}{2^{23}} $$

Formats with fewer mantissa bits simply stop the sum earlier.

##### The Meaning of "Precision"

When we say `bfloat16` has "less precision" than `float32`, we don't mean fewer decimal places in the way humans think. We mean it has **fewer bits in its mantissa**.

-   `float32`: 1 sign bit, 8 exponent bits, 23 mantissa bits.
-   `float16`: 1 sign bit, 5 exponent bits, 10 mantissa bits. It has more mantissa bits than `bfloat16`, which means a finer grid (more precision). However, its small exponent gives it a much narrower range (min to max).
-   `bfloat16`: 1 sign bit, 8 exponent bits, 7 mantissa bits. Fewer mantissa bits means a coarser grid (less precision). Its exponent matches `float32`, so it has the same range as `float32` and far more range than `float16`.

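This means `bfloat16` can represent only $2^7 = 128$ evenly spaced values between any two consecutive powers of two. By comparison, `float16` gets $2^{10} = 1024$ and `float32` gets $2^{23}$. The same number of steps is stretched over every interval, so the absolute gap grows with the magnitude. For small numbers (like 3.14), the representable values are very close together, but for large numbers the gaps become huge!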
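You can pull these fields out of a real `float32` yourself. The sketch below uses Python's standard `struct` module to reinterpret the 4 bytes of a `float32` as an unsigned integer. It then masks out the sign, exponent, and mantissa and rebuilds the value with the formula above. The helper name `float32_fields` is hypothetical.

```python
import struct

def float32_fields(x):
    """Hypothetical helper: split a value, stored as float32, into its bit fields."""
    (bits,) = struct.unpack(">I", struct.pack(">f", x))  # the raw 32-bit pattern
    sign = bits >> 31
    exponent = (bits >> 23) & 0xFF   # 8 exponent bits, stored with a bias of 127
    mantissa = bits & 0x7FFFFF       # 23 mantissa bits; the implicit leading 1 is not stored
    return sign, exponent, mantissa

sign, exponent, mantissa = float32_fields(70000.0)
print(f"sign={sign}  exponent={exponent:08b} ({exponent} - 127 = {exponent - 127})  "
      f"mantissa={mantissa:023b}")

# Rebuild the value with (-1)^S * 1.M * 2^(E - bias)
value = (-1) ** sign * (1 + mantissa / 2**23) * 2 ** (exponent - 127)
print(value)  # 70000.0
```

The exponent field prints as `10001111` (143, i.e. 16 + 127). The mantissa begins with the `0001000101110000` digits used in the walkthrough that follows.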

##### Detailed explanation: Why 70,144?

The number `70000` is simply not one of the numbers that can be perfectly formed with `bfloat16`'s limited 7-bit mantissa at that large exponent range.


Let's write the number `70,000` in binary: `1 0001 0001 0111 0000`.

Binary scientific notation always starts with `1.`, so for `70,000` we move the binary point 16 places to the left (a factor of $2^{16}$):
$$ 1. \underbrace{0001000101110000}_{\text{16 binary digits}} \times 2^{16} $$

The **mantissa** (the part after the `1.`) is where the precision limit strikes.

A **`float32`** has **23 bits** for its mantissa. It can easily store those binary digits with room to spare. The number `70,000` is stored perfectly.
* `0` sign bit, `0001000101110000 0000000` as the 23 mantissa bits, and `00010000` (16) as the exponent. We omit the bias for simplicity; the field actually stored is 16 + 127 = 143 = `10001111`.


A **`bfloat16`** only has **7 bits** for its mantissa.
* `0` sign bit, the implicit `1.`, and the first 7 mantissa digits `0001000`, rounded up to `0001001`. The exponent is `00010000` (bias omitted again).
* Reconstructing the value from the two set mantissa bits ($m_4$ and $m_7$): $\left(1 + \frac{1}{16} + \frac{1}{128}\right) \times 2^{16} = 65536 + 4096 + 512 = 70144$

`bfloat16` must take that 16-digit binary sequence and round it to fit into just 7 bits. This forces a loss of information, even for a whole number.

1.  **Original Mantissa:** `0001000101110000`
2.  **`bfloat16` capacity:** Can only store the first 7 digits: `0001000`.
3.  **Rounding:** The discarded digits are `101110000`. That is more than half of one step (`100000000`), so round-to-nearest bumps the 7-bit mantissa up to `0001001`.

So, `bfloat16` ends up storing the number as $1.0001001_2 \times 2^{16}$.


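Think of it like trying to measure `70,000` millimeters with a ruler that only has markings every `512` millimeters. Between $2^{16}$ and $2^{17}$, the 7-bit mantissa splits the interval into $2^7 = 128$ steps of $2^{16} / 2^7 = 512$. You can't land on `70,000` exactly, so you must choose the closest mark.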

The two closest "marks" that `bfloat16` can represent in that range are:
-   `69,632` (368 below)
-   `70,144` (144 above)

Since `70,000` is closer to `70,144`, the transmutation spell **rounds** it to that value. This is not an error. It is the result of sacrificing precision to maintain the vast numerical range of `float32`, and that robustness is exactly why `bfloat16` is the preferred elixir for training colossal neural networks! You have witnessed the fundamental trade-off of modern AI hardware!
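The rounding can be reproduced bit by bit. A `bfloat16` has the same layout as the top 16 bits of a `float32`. A common way to convert is to round the 32-bit pattern to its upper half using round-to-nearest-even.

The sketch below does exactly that and cross-checks the result against PyTorch's own cast. It also confirms the two neighbouring marks on the 512-wide ruler. The helper names are hypothetical, and NaN handling is deliberately omitted.

```python
import struct
import torch

def float32_to_bfloat16_bits(x):
    """Hypothetical helper: keep the top 16 bits of a float32, rounding to nearest even.
    Teaching sketch only: NaN inputs are not handled."""
    (bits,) = struct.unpack(">I", struct.pack(">f", x))
    lsb = (bits >> 16) & 1                 # lowest bit that survives the cut
    return (bits + 0x7FFF + lsb) >> 16     # add just under half a step; ties go to even

def bfloat16_bits_to_float(b):
    """Hypothetical helper: pad the 16 dropped bits with zeros and read back a float."""
    (value,) = struct.unpack(">f", struct.pack(">I", b << 16))
    return value

b = float32_to_bfloat16_bits(70000.0)
print(hex(b), bfloat16_bits_to_float(b))                 # 0x4789 70144.0
print(torch.tensor(70000.0).to(torch.bfloat16).item())   # PyTorch agrees: 70144.0

# Neighbouring marks on the bfloat16 ruler between 2**16 and 2**17 (spacing 512)
for mark in (69632.0, 70144.0):
    print(mark, torch.tensor(mark).to(torch.bfloat16).item() == mark)  # both exact
```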



### Why `bfloat16` is Better for Transformers: A Tale of Range and Rebellion

Mwahahaha! Now for a secret that separates the masters from the mere dabblers! Both `float16` and `bfloat16` use 16 bits, but they do so with diabolically different strategies. Understanding this is key to training modern marvels like Transformers!

The world of `float32` is a stable, predictable realm. But it is slow and memory-hungry! When we attempt to accelerate our dark arts with `float16`, we encounter a terrible problem: **The Tyranny of a Tiny Range**.

#### The Peril of `float16`: An Unstable Concoction

`float16` dedicates more bits to its mantissa (precision) but starves its exponent (range). Its numerical world is small: normal values span roughly `6.1 x 10^-5` to `65,504`. Anything larger becomes `inf` (infinity). Smaller values lose precision, and anything much below about `6 x 10^-8` (the smallest subnormal) vanishes to zero.

During the chaotic process of training a massive Transformer, values can fluctuate wildly. This is where the tyranny of `float16` strikes hardest:

*   **Exploding Gradients**: Imagine a scenario deep within your network where a series of large gradients are multiplied. Even with normalization, an intermediate calculation can easily exceed 65,504. For instance, the Adam optimizer tracks a running average of squared gradients (its `v` term). A gradient of just 256 already squares to 65,536, beyond `float16`'s limit. If this value overflows to `inf`, the weight update can turn into `NaN` (Not a Number), and your entire training process collapses into a fiery numerical singularity!
*   **Vanishing Activations**: Inside a Transformer, attention scores are passed through a Softmax function. If the input values (logits) are very large negative numbers, the resulting probabilities can become smaller than `float16`'s minimum representable value. They are rounded down to zero, and that part of your model stops learning entirely!
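The vanishing problem is easy to reproduce with a cast alone. The sketch below assumes a hypothetical softmax probability of $e^{-20}$ (about $2 \times 10^{-9}$). It computes the value in `float32` and then stores it in each 16-bit format.

```python
import torch

# A hypothetical, very unlikely attention probability: exp(-20), about 2.06e-09
prob = torch.exp(torch.tensor(-20.0))         # computed in float32

print(prob.item())                            # about 2.06e-09: fine in float32
print(prob.to(torch.float16).item())          # 0.0: below float16's smallest subnormal
print(prob.to(torch.bfloat16).item())         # roughly 2e-09: bfloat16 keeps it
```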

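To combat this, practitioners use **loss scaling**. It works in two steps:

1.  Multiply the loss by a large factor before backpropagation, so that tiny gradients stay above `float16`'s underflow threshold.
2.  Divide the gradients by the same factor before the weight update.

PyTorch's automatic mixed precision tools automate this with a `GradScaler`, but it is still an extra moving part to babysit!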
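Here is the idea behind loss scaling, reduced to a single number. This is a sketch, not what `GradScaler` does internally. The gradient value `1e-8` and the fixed scale `1024` are hypothetical; real scalers adjust the factor dynamically.

```python
import torch

tiny_grad = 1e-8   # hypothetical gradient value, below float16's smallest subnormal (~6e-8)
scale = 1024.0     # hypothetical fixed loss-scale factor

# Without scaling: the gradient underflows to zero when stored in float16
unscaled = torch.tensor(tiny_grad, dtype=torch.float16)

# With scaling: store the scaled value in float16, then unscale in float32
scaled = torch.tensor(tiny_grad * scale, dtype=torch.float16)
recovered = scaled.to(torch.float32) / scale

# bfloat16 needs no scaling: its exponent reaches down to about 1e-38
bf16 = torch.tensor(tiny_grad, dtype=torch.bfloat16)

print(f"float16, no scaling:  {unscaled.item()}")
print(f"float16, scaled:      {recovered.item():.3e}")
print(f"bfloat16, no scaling: {bf16.item():.3e}")
```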

#### The `bfloat16` Rebellion: Sacrificing Precision for Power!

The great minds at Google Brain, in their quest for ultimate power, forged a new weapon: the **Brain Floating-Point Format**, or `bfloat16`! They looked at the chaos of `float16` and made a brilliant, rebellious choice.

They designed `bfloat16` to have the **same number of exponent bits as `float32`** (8 bits). This gives it the exact same colossal dynamic range, spanning from `1.18 x 10^-38` to `3.4 x 10^38`. It can represent gargantuan numbers and infinitesimally small ones without breaking a sweat.

The price? It has fewer mantissa bits (7 bits) than `float16` (10 bits), giving it less precision. But here is the profound secret, backed by countless experiments in the deepest labs: **neural networks are incredibly resilient to low precision.**

Why do the inaccuracies not hurt?
*   **Stochastic Nature of Training**: We train models using stochastic gradient descent on mini-batches of data. This process is inherently noisy! The tiny inaccuracies introduced by `bfloat16`'s rounding are like a single drop of rain in a hurricane—they are statistically insignificant compared to the noise already present in the training process.
*   **Error Accumulation is Not Catastrophic**: As the essay [The Hardware Lottery](https://arxiv.org/abs/2009.06489) and other deep learning practitioners have noted, the errors from low precision tend to average out over millions of updates. The network's learning direction isn't meaningfully altered. The gradient still points downhill, even if it's a slightly wobblier path.

> "For the volatile, chaotic world of deep learning, a vast and stable **range** is far more important than surgical **precision**."
>
> — **_Prof. Torchenstein_**

#### The Transformer's Elixir of Choice

For training and fine-tuning, `bfloat16` is the undisputed champion, the elixir that fuels the titans of AI.

1.  **Training Stability**: Its `float32`-like range means values no longer overflow to `inf` in optimizer states or underflow to zero in softmax. You can usually throw away the clumsy crutch of loss scaling.
2.  **Memory Efficiency**: Like `float16`, it cuts your model's memory footprint in half compared to `float32`. This allows you to train larger models or use larger batch sizes, accelerating your path to discovery.
3.  **Hardware Acceleration**: It is natively supported on the most powerful instruments in any modern laboratory: Google TPUs and NVIDIA's latest GPUs (Ampere architecture and newer, like the A100 or RTX 30/40 series).

**The Rogues' Gallery: Who Uses `bfloat16`?**
The most powerful creations of our time were forged in the fires of `bfloat16`. Giants like **Google's T5 and BERT**, **Meta's Llama 2**, the **Falcon** models, and many more rely on `bfloat16` for stable and efficient training. 

**The Verdict for Your Lab:**
*   **For Training & Fine-Tuning**: `bfloat16` is your weapon of choice. It is the modern standard for a reason.
*   **For Inference**: `float16` is often perfectly acceptable. After a model is trained, the range of values it processes is more predictable, making `float16`'s higher precision and wider hardware support a safe and efficient option.
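One common way to apply this verdict in PyTorch is automatic mixed precision with `torch.autocast`. Inside the context, eligible operations such as `linear` and matmul run in `bfloat16`, while the model's weights stay in `float32`. That differs from casting the whole model with `.to(torch.bfloat16)`.

The sketch below runs on the CPU so it works anywhere; on a GPU you would pass `device_type="cuda"`. The layer sizes are arbitrary.

```python
import torch

model = torch.nn.Linear(16, 4)   # weights stay in float32
x = torch.randn(8, 16)

# Inside autocast, eligible ops (linear, matmul, ...) run in bfloat16
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = model(x)

print(model.weight.dtype)  # torch.float32: the master weights are untouched
print(out.dtype)           # torch.bfloat16: the linear layer ran in reduced precision
```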


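```python
import torch

# --- Range vs. Precision Demonstration ---
print("--- A Tale of Two Half-Precisions ---")

# A number just beyond the float16 limit to demonstrate range
high_number = torch.tensor(70000.0, dtype=torch.float32)
# A high-precision number to demonstrate precision
precise_number = torch.tensor(3.14159265, dtype=torch.float32)

# --- The Range Test ---
print(f"\n--- The Range Test (Number: {high_number.item()}) ---")
print(f"Original (Float32):  {high_number.item()}")

bf16_high = high_number.to(torch.bfloat16)
print(f"BFloat16: {bf16_high.item()} (Stays finite, rounded to the nearest bfloat16 value!)")

fp16_high = high_number.to(torch.float16)
print(f"Float16:  {fp16_high.item()} (Overflows to infinity! A tragic failure!)")

# --- The Precision Test ---
print(f"\n--- The Precision Test (Number: {precise_number.item()}) ---")
print(f"Original (Float32):  {precise_number.item():.8f}")

fp16_precise = precise_number.to(torch.float16)
print(f"Float16:  {fp16_precise.item():.8f} (10 mantissa bits: the finer ruler)")

bf16_precise = precise_number.to(torch.bfloat16)
print(f"BFloat16: {bf16_precise.item():.8f} (7 mantissa bits: the coarser ruler)")
```

```text
--- A Tale of Two Half-Precisions ---

--- The Range Test (Number: 70000.0) ---
Original (Float32):  70000.0
BFloat16: 70144.0 (Stays finite, rounded to the nearest bfloat16 value!)
Float16:  inf (Overflows to infinity! A tragic failure!)

--- The Precision Test (Number: 3.1415927410125732) ---
Original (Float32):  3.14159274
Float16:  3.14062500 (10 mantissa bits: the finer ruler)
BFloat16: 3.14062500 (7 mantissa bits: the coarser ruler)
```

For π, both formats happen to land on the same value, `3.140625`, because it is the nearest mark on both rulers. The finer `float16` grid only pays off for numbers that fall between `bfloat16`'s marks.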
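To see `float16`'s finer grid actually win, pick a number that falls between `bfloat16`'s marks. Near 1.0, `float16` steps by $2^{-10}$ while `bfloat16` steps by $2^{-7}$. The sketch below uses `1.001`, an arbitrary choice that separates them.

```python
import torch

x = torch.tensor(1.001)  # float32; near 1.0, float16 steps by 2**-10 and bfloat16 by 2**-7

fp16 = x.to(torch.float16).item()
bf16 = x.to(torch.bfloat16).item()

print(f"float16:  {fp16:.8f}")   # 1.00097656 (1 + 1/1024)
print(f"bfloat16: {bf16:.8f}")   # 1.00000000 (the next mark up, 1 + 1/128, is farther away)
```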


### Integer Types (The Counting Stones)
*For when you need to count, index, or represent discrete information like image pixel values or class labels.*

-   `torch.int64` (`torch.long`): The 64-bit grandmaster of integers. It is the default integer `dtype` (Python integers become `int64` tensors). It is also the traditional type for indices, such as the token IDs you feed to embedding layers to look up entries in a large vocabulary.
-   `torch.int32` (`torch.int`): A solid 32-bit integer, perfectly suitable for most counting tasks.
-   `torch.uint8`: An 8-bit unsigned integer, representing values from 0 to 255. The undisputed king for storing image data, where each pixel in an RGB channel has a value in this exact range!
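Integer types have their own limits and their own traps. The sketch below makes three points, using made-up pixel values:

- Python integers default to `int64`.
- `torch.iinfo` (the integer counterpart of `torch.finfo`) reports a type's range.
- `uint8` arithmetic silently wraps around modulo 256.

```python
import torch

ids = torch.tensor([3, 1, 4])
print(ids.dtype)                   # torch.int64: Python ints default to int64

info = torch.iinfo(torch.uint8)    # integer limits
print(info.min, info.max)          # 0 255

pixels = torch.tensor([250, 10], dtype=torch.uint8)
print(pixels + 10)                 # wraps modulo 256: 250 + 10 -> 4, 10 + 10 -> 20
print(pixels.float() / 255.0)      # convert to float before doing real math on images
```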



Now for the counting stones—the integer types. Their purpose is not precision, but to hold whole numbers. Observe their varying sizes.



```python
import torch

# Summoning tensors of different integer dtypes
tensor_i64 = torch.tensor(1000, dtype=torch.int64)
tensor_i32 = torch.tensor(1000, dtype=torch.int32)
tensor_i16 = torch.tensor(1000, dtype=torch.int16)
tensor_i8 = torch.tensor(100, dtype=torch.int8)    # int8 tops out at 127, so 1000 would not fit
tensor_ui8 = torch.tensor(255, dtype=torch.uint8)

print("--- Integer Memory Footprints ---")
print(f"torch.int64: Memory: {tensor_i64.element_size()} bytes")
print(f"torch.int32: Memory: {tensor_i32.element_size()} bytes")
print(f"torch.int16: Memory: {tensor_i16.element_size()} bytes")
print(f"torch.int8:  Memory: {tensor_i8.element_size()} bytes")
print(f"torch.uint8: Memory: {tensor_ui8.element_size()} bytes")
```

```text
--- Integer Memory Footprints ---
torch.int64: Memory: 8 bytes
torch.int32: Memory: 4 bytes
torch.int16: Memory: 2 bytes
torch.int8:  Memory: 1 bytes
torch.uint8: Memory: 1 bytes
```


### The Transmutation Spell: Witnessing the Effects of Casting

Now that you understand the properties of each `dtype`, witness what happens when we perform the transmutation! Casting from a higher precision `dtype` to a lower one is a **lossy** operation. You gain speed and save memory, but at the cost of precision!

Observe the fate of our high-precision number as we cast it down the alchemical ladder.



```python
import torch

# Our original, high-precision tensor
pi_fp64 = torch.tensor(3.141592653589793, dtype=torch.float64)
print(f"Original (float64): {pi_fp64.item():.15f}")

# Cast it down
pi_fp32 = pi_fp64.to(torch.float32)
print(f"Cast to float32:    {pi_fp32.item():.15f} (Precision lost!)")

pi_fp16 = pi_fp64.to(torch.float16)
print(f"Cast to float16:    {pi_fp16.item():.15f} (More precision lost!)")

pi_bf16 = pi_fp64.to(torch.bfloat16)
print(f"Cast to bfloat16:   {pi_bf16.item():.15f} (Less precision, but still good!)")

# Casting floats to integers truncates the decimal part entirely!
integer_pi = pi_fp64.to(torch.int)
print(f"\nCast to integer: {integer_pi.item()} (Decimal part vanished!)")
```

```text
Original (float64): 3.141592653589793
Cast to float32:    3.141592741012573 (Precision lost!)
Cast to float16:    3.140625000000000 (More precision lost!)
Cast to bfloat16:   3.140625000000000 (Less precision, but still good!)

Cast to integer: 3 (Decimal part vanished!)
```
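Two casting details are worth checking yourself. First, float-to-integer casts truncate toward zero; they do not round down. Second, a round trip through a smaller format cannot restore the lost bits. The sketch below uses arbitrary sample values.

```python
import torch

x = torch.tensor([2.7, -2.7])

print(x.to(torch.int32))                 # [2, -2]: truncation toward zero
print(torch.floor(x))                    # [2., -3.]: floor rounds toward -infinity
print(torch.round(x).to(torch.int32))    # [3, -3]: round first if you want the nearest integer

# Casting down and back up does not restore the discarded bits
pi32 = torch.tensor(3.14159265)
round_trip = pi32.to(torch.float16).to(torch.float32)
print(round_trip.item(), round_trip.item() == pi32.item())  # 3.140625 False
```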


This ladder is the same range-versus-precision trade-off we tested with the two half-precisions. Each rung keeps fewer mantissa bits, and no later cast can bring the discarded digits back.



### Boolean Type (The Oracle)
*Represents the fundamental truths of the universe: `True` or `False`.*

-   `torch.bool`: The result of all your logical incantations (`>`, `<`, `==`). Essential for creating masks to filter and select elements from your tensors.
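Masks are where `torch.bool` earns its keep. Here is a minimal sketch with made-up readings:

```python
import torch

readings = torch.tensor([-1.5, 2.0, -0.5, 4.0])

mask = readings > 0          # comparisons produce torch.bool tensors
print(mask.dtype)            # torch.bool
print(readings[mask])        # keeps only the positive values: 2.0 and 4.0
print(mask.sum().item())     # 2: True counts as 1 when summed
```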

---

## Part 3: The Lair of Computation - Mastering Devices (`device`)

### 1. A Tour of Computational Realms

A tensor's `dtype` is its soul, but its `.device` is its home—the very dimension where its calculations will be performed. Choosing the right device is the key to unlocking diabolical computational speed!

-   **`cpu`**: The Central Processing Unit, the reliable, ever-present brain of your machine. It's a generalist, capable of any task. However, it has only a handful of cores (each with some vector units), so its parallelism is modest compared with a GPU. For small tensors and simple operations, it's perfectly adequate.
-   **`cuda`**: The NVIDIA GPU! This is the roaring heart of the deep learning revolution. A GPU is a specialist, containing thousands of cores designed for one purpose: massively parallel computation. Moving your tensors and models here is **essential** for training any serious neural network.
-   **`mps`**: Metal Performance Shaders, Apple's answer to CUDA on Apple silicon (M-series) Macs. If you are wielding a modern Mac, this device will unleash the power of its integrated GPU.

#### How PyTorch Supports So Many Devices: The Magic of ATen

How can a single command like `torch.matmul()` work on a CPU, an NVIDIA GPU, and an Apple chip? The secret lies in PyTorch's core library: **ATen**.

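Think of ATen as the grand dispatcher of our laboratory. When you issue a command, PyTorch inspects the tensor's `.device` (and `dtype`) and routes the call to a kernel written for that hardware. That kernel may in turn lean on highly optimized vendor libraries:
-   If `device='cpu'`: PyTorch's own CPU kernels, plus libraries such as `oneDNN` and a BLAS implementation (e.g., MKL or OpenBLAS) for heavy operations.
-   If `device='cuda'`: CUDA kernels, plus NVIDIA libraries such as `cuBLAS` (matrix multiplication) and `cuDNN` (convolutions and other neural-network primitives).
-   If `device='mps'`: Apple's `Metal` framework, via Metal Performance Shaders.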

This brilliant design makes your PyTorch code incredibly portable. You write the incantation once, and ATen ensures it is executed with maximum power on whatever hardware you possess!
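To see this portability in action, the sketch below runs one function, written once, on every device type this machine exposes. On a CPU-only machine the loop simply visits `cpu`. The helper names are hypothetical.

```python
import torch

def available_devices():
    """Hypothetical helper: list the device types this machine can use."""
    names = ["cpu"]
    if torch.cuda.is_available():
        names.append("cuda")
    if torch.backends.mps.is_available():
        names.append("mps")
    return names

def sum_of_squares(x):
    # Written once; PyTorch dispatches each op to the kernel for x.device
    return (x * x).sum()

results = {}
for name in available_devices():
    x = torch.arange(4.0, device=name)       # [0., 1., 2., 3.] on that device
    results[name] = sum_of_squares(x).item()

print(results)  # every device reports 14.0 (0 + 1 + 4 + 9)
```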



### 2. The Ritual: Dynamic Device Placement

A true PyTorch master does not hardcode their device! That is the way of the amateur. We shall write a glorious, platform-agnostic spell that automatically detects and selects the most powerful computational device available.

The hierarchy is clear: `CUDA` is the sanctum sanctorum, `MPS` is the respected wizard's tower, and `CPU` is our reliable home laboratory. Our code shall seek the most powerful realm first.



```python
import torch

# Our Grand Spell for Selecting the Best Device
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Mwahahaha! We have awakened the CUDA beast!")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Behold! The power of Apple's Metal Performance Shaders!")
else:
    device = torch.device("cpu")
    print("The humble CPU will have to suffice for today's experiments.")

print(f"\nSelected device: {device}\n")

# --- Summoning and Teleporting Tensors ---

# 1. Summon a tensor directly on the chosen device
tensor_on_device = torch.randn(2, 3, device=device)
print(f"Tensor summoned directly on '{tensor_on_device.device}'")
print(tensor_on_device)

# 2. Teleport a CPU tensor to the device using the .to() spell
cpu_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(f"\nA CPU tensor, minding its own business: {cpu_tensor.device}")

teleported_tensor = cpu_tensor.to(device)
print(f"Teleported to '{teleported_tensor.device}'!")
print(teleported_tensor)

# IMPORTANT: Operations between tensors on different devices will FAIL!
# On a CPU-only machine both tensors already share a device, so the demo is skipped.
if device.type != "cpu":
    try:
        result = cpu_tensor + teleported_tensor
    except RuntimeError as e:
        print(f"\nAs expected, chaos ensues: {e}")
```




## Part 4: Practical Speed Trials

Time to pit `float32`, `float16`, and `bfloat16` against each other, and then to take a first look at quantized (`int8`) operations.


### Speed Trials: Precision vs. Performance

Now for a truly electrifying experiment! We shall create colossal tensors of different floating-point `dtypes` and subject them to a barrage of mathematical operations. On GPUs with native 16-bit support, the 16-bit formats are usually much faster. On a CPU, the outcome depends heavily on your processor!



```python
import time
import torch

# A utility for our speed trials
def time_operation(tensor):
    start_time = time.time()
    # A sequence of intense mathematical transformations!
    torch.exp(torch.sin(tensor))
    end_time = time.time()
    return end_time - start_time

# A colossal tensor for our experiment!
large_tensor_cpu = torch.randn(2000, 2000)

# --- SPEED TRIALS ON CPU ---
print("--- CPU Speed Trials ---")
time_fp32_cpu = time_operation(large_tensor_cpu.clone())
print(f"Float32 on CPU took: {time_fp32_cpu:.6f} seconds")

# Float16 on CPU is often slower: many CPUs lack native float16 arithmetic
time_fp16_cpu = time_operation(large_tensor_cpu.clone().to(torch.float16))
print(f"Float16 on CPU took: {time_fp16_cpu:.6f} seconds (Often slower on CPU!)")

# BFloat16 on CPU
time_bf16_cpu = time_operation(large_tensor_cpu.clone().to(torch.bfloat16))
print(f"BFloat16 on CPU took: {time_bf16_cpu:.6f} seconds\n")


# --- SPEED TRIALS ON GPU (if available) ---
if torch.cuda.is_available():
    print("--- GPU Speed Trials ---")
    large_tensor_gpu = large_tensor_cpu.to("cuda")

    # GPU work is asynchronous: synchronize before stopping the timer
    torch.cuda.synchronize()
    start_time = time.time()
    torch.exp(torch.sin(large_tensor_gpu.clone()))
    torch.cuda.synchronize()
    time_fp32_gpu = time.time() - start_time
    print(f"Float32 on GPU took: {time_fp32_gpu:.6f} seconds")

    large_tensor_gpu_fp16 = large_tensor_gpu.clone().to(torch.float16)
    torch.cuda.synchronize()
    start_time = time.time()
    torch.exp(torch.sin(large_tensor_gpu_fp16))
    torch.cuda.synchronize()
    time_fp16_gpu = time.time() - start_time
    print(f"Float16 on GPU took: {time_fp16_gpu:.6f} seconds (Diabolical Speed!)")

    large_tensor_gpu_bf16 = large_tensor_gpu.clone().to(torch.bfloat16)
    torch.cuda.synchronize()
    start_time = time.time()
    torch.exp(torch.sin(large_tensor_gpu_bf16))
    torch.cuda.synchronize()
    time_bf16_gpu = time.time() - start_time
    print(f"BFloat16 on GPU took: {time_bf16_gpu:.6f} seconds (Also incredibly fast!)")
else:
    print("GPU not available for speed trials. A true pity!")
```

```text
--- CPU Speed Trials ---
Float32 on CPU took: 0.011775 seconds
Float16 on CPU took: 0.005535 seconds (Often slower on CPU!)
BFloat16 on CPU took: 0.001954 seconds

--- GPU Speed Trials ---
Float32 on GPU took: 0.054487 seconds
Float16 on GPU took: 0.000000 seconds (Diabolical Speed!)
BFloat16 on GPU took: 0.000000 seconds (Also incredibly fast!)
```

Take these numbers with a grain of salt! Several things distort them:

- Each operation ran only once.
- `time.time()` has coarse resolution on some platforms, a likely reason for the `0.000000` readings.
- The first GPU operation pays one-time start-up costs, such as loading kernels and allocating memory.
- The `float32` GPU timing also includes a `.clone()`.

Together, these likely explain why `float32` looks slower on the GPU than on the CPU here. Even the CPU `float16` result contradicts its own label in this run. A fair benchmark warms up first and averages over many repetitions.
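A fairer comparison lets a benchmarking utility handle repetition and GPU synchronization. The sketch below uses `torch.utils.benchmark.Timer`, whose documentation says it performs warm-ups and synchronizes CUDA when needed. It makes three assumptions:

- The tensor size is an arbitrary choice.
- The repeat count is also arbitrary.
- The CPU run skips `float16` to stay portable, because half-precision math support on CPUs varies across PyTorch versions and processors.

```python
import torch
from torch.utils import benchmark

def bench(dtype, device="cpu", size=1000):
    """Mean seconds per run of exp(sin(x)) for a size x size tensor."""
    x = torch.randn(size, size, device=device).to(dtype)
    timer = benchmark.Timer(
        stmt="torch.exp(torch.sin(x))",
        globals={"torch": torch, "x": x},
    )
    return timer.timeit(number=10).mean

results = {f"cpu {dtype}": bench(dtype) for dtype in (torch.float32, torch.bfloat16)}
if torch.cuda.is_available():
    for dtype in (torch.float32, torch.float16, torch.bfloat16):
        results[f"cuda {dtype}"] = bench(dtype, device="cuda")

for name, seconds in results.items():
    print(f"{name:22} {seconds * 1e3:.3f} ms")
```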


### A Glimpse into the Dark Art of Quantization

Now, my apprentice, I shall grant you a peek into a forbidden realm: **Quantization**. This is the art of shrinking your models, making them faster and more efficient for inference, by converting their weights from high-precision `float32` into low-precision integers, most commonly `int8`!

**Why does this matter?**
Imagine forging a colossal golem (`float32` model) that requires immense energy (memory and compute) to move. Quantization is like transmuting that golem into a swarm of nimble, lightning-fast sprites (`int8` model).

-   **Speed:** On CPUs and accelerators with dedicated int8 instructions, integer arithmetic can be considerably faster than `float32` math, and every value moved through memory is 4x smaller.
-   **Size:** An `int8` model can be **4x smaller** than its `float32` counterpart! This is critical for deploying models on devices with limited memory, like phones or embedded systems.
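The affine (asymmetric) per-tensor scheme used here can be written out explicitly; other schemes, such as symmetric or per-channel quantization, also exist. With scale $s$ and integer zero point $z$, a float range $[x_{\min}, x_{\max}]$ maps onto the int8 range $[-128, 127]$:

$$ s = \frac{x_{\max} - x_{\min}}{127 - (-128)}, \qquad z = \operatorname{round}\!\left(-128 - \frac{x_{\min}}{s}\right) $$

$$ q = \operatorname{clamp}\!\left(\operatorname{round}\!\left(\frac{x}{s} + z\right),\ -128,\ 127\right), \qquad \hat{x} = s\,(q - z) $$

Now consider a dot product between an input (scale $s_x$, zero point $z_x$) and a weight column (scale $s_w$, zero point $z_w$). Both zero points must be subtracted before the scales are reapplied:

$$ \sum_k x_k w_k \approx s_x s_w \sum_k \left(q^{x}_k - z_x\right)\left(q^{w}_k - z_w\right) $$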

We are merely scratching the surface here, but let us witness a simple demonstration of this power!



```python
import time
import torch

# --- A Simple Quantization Demonstration ---

# 1. Our "Model": A large float32 weight matrix
fp32_weights = torch.randn(2048, 1024)
# Our "Input": A typical input tensor
input_tensor = torch.randn(1, 2048)

# 2. The Quantization Spell: asymmetric (affine) per-tensor int8 quantization
def quantize_tensor(tensor):
    # Map the float range [min, max] onto the int8 range [q_min, q_max]
    q_min, q_max = -128, 127
    scale = ((tensor.max() - tensor.min()) / (q_max - q_min)).item()
    zero_point = int(round(q_min - tensor.min().item() / scale))

    # Quantize, clamping so rounding can never spill outside the int8 range
    quantized_tensor = (tensor / scale + zero_point).round().clamp(q_min, q_max).to(torch.int8)
    return quantized_tensor, scale, zero_point

int8_weights, w_scale, w_zero_point = quantize_tensor(fp32_weights)
# The input gets its own scale and zero point
int8_input, in_scale, in_zero_point = quantize_tensor(input_tensor)

print(f"Original fp32 weights size: {fp32_weights.element_size() * fp32_weights.nelement() / 1024**2:.2f} MB")
print(f"Quantized int8 weights size: {int8_weights.element_size() * int8_weights.nelement() / 1024**2:.2f} MB (4x smaller!)\n")

# 3. Performance Comparison

# Time float32 matrix multiplication
start_time = time.perf_counter()
fp32_output = torch.matmul(input_tensor, fp32_weights)
fp32_time = time.perf_counter() - start_time

# Simulated int8 matmul: subtract the zero points and accumulate in int32 so the
# products cannot overflow. Plain torch.matmul has no fast int8 path, so this shows
# the arithmetic, not the speed; a real speedup needs quantized kernels
# (e.g., `torch.nn.quantized.functional.linear`).
start_time = time.perf_counter()
int32_accumulator = torch.matmul(
    int8_input.to(torch.int32) - in_zero_point,
    int8_weights.to(torch.int32) - w_zero_point,
)
int8_time = time.perf_counter() - start_time

print(f"Float32 matmul took: {fp32_time:.6f} seconds")
print(f"Simulated Int8 matmul took: {int8_time:.6f} seconds (Conceptually MUCH faster on compatible hardware!)")

# 4. Compare the results: the int32 accumulator carries the product of both scales
dequantized_output = int32_accumulator.float() * (in_scale * w_scale)

average_diff = torch.mean(torch.abs(fp32_output - dequantized_output))
print(f"\nAverage difference between fp32 and dequantized output: {average_diff:.4f}")
print("A small price in precision for a monumental gain in speed and size! Mwahahaha!")
```

```text
Original fp32 weights size: 8.00 MB
Quantized int8 weights size: 2.00 MB (4x smaller!)

```

The timings and the average difference depend on your hardware and on the random weights, so run the cell to see your own values.
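PyTorch also ships built-in quantized tensors. As a sanity check, the sketch below recomputes the same affine parameters on a fresh random vector. It then compares the hand-rolled integers with `torch.quantize_per_tensor`. A difference of one step can appear at the rare values where the two implementations round a tie differently.

```python
import torch

x = torch.randn(1000)

# Same per-tensor affine parameters as the hand-written quantization spell
q_min, q_max = -128, 127
scale = ((x.max() - x.min()) / (q_max - q_min)).item()
zero_point = int(round(q_min - x.min().item() / scale))

manual = (x / scale + zero_point).round().clamp(q_min, q_max).to(torch.int8)

# PyTorch's built-in quantized tensor with the same scale and zero point
q = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8)
print(q.int_repr()[:5])   # the stored int8 values
print(manual[:5])
print((q.int_repr().int() - manual.int()).abs().max().item())  # at most 1
print((q.dequantize() - x).abs().max().item() <= scale)        # error within one step
```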


## Your Mission: Forge Your Own Creation!

A true master never ceases to practice. I leave you with these challenges to solidify your newfound power. Do not be afraid to experiment—to the lab!

1.  **The Alchemist's Transmutation**:
    *   Create a `(4, 4)` tensor of random numbers on the `cpu` with the default `float32` dtype.
    *   Print its `dtype` and `device`.
    *   Transmute this tensor into a `bfloat16` tensor and teleport it to the most powerful device your machine possesses.
    *   Print its new `dtype` and `device`.

2.  **The Precision Analyst**:
    *   Create a `float64` tensor containing the number `[3.141592653589793]`.
    *   Cast it down to `float32` and then to `float16`.
    *   Print the value at each stage. Observe the loss of precision as you step down the ladder of alchemy!

3.  **The Device-Aware Calculator**:
    *   Create two large `(1000, 1000)` random tensors, `a` and `b`.
    *   Write a piece of code that checks if a powerful device (`cuda` or `mps`) is available.
    *   If it is, move both tensors to that device before performing a matrix multiplication (`a @ b`).
    *   If not, perform the operation on the `cpu`.
    *   Print the device on which the operation was performed. This is the foundation of all robust neural network code!



```python
import torch

# Your code for the final challenges goes here!
# Challenges 1 and 3 reuse the `device` chosen by the device-selection spell in Part 3.

# Challenge 1: The Alchemist's Transmutation
print("--- Challenge 1: The Alchemist's Transmutation ---")
cpu_tensor_fp32 = torch.randn(4, 4)
print(f"Original tensor device: {cpu_tensor_fp32.device}, dtype: {cpu_tensor_fp32.dtype}")

# Use the device we selected earlier!
transmuted_tensor = cpu_tensor_fp32.to(device=device, dtype=torch.bfloat16)
print(f"Transmuted tensor device: {transmuted_tensor.device}, dtype: {transmuted_tensor.dtype}\n")


# Challenge 2: The Precision Analyst
print("--- Challenge 2: The Precision Analyst ---")
pi_fp64 = torch.tensor([3.141592653589793], dtype=torch.float64)
pi_fp32 = pi_fp64.to(torch.float32)
pi_fp16 = pi_fp32.to(torch.float16)   # step down the ladder one rung at a time

print(f"Float64 (Original): {pi_fp64.item():.15f}")
print(f"Float32 (Reduced):  {pi_fp32.item():.15f}")
print(f"Float16 (Sacrificed): {pi_fp16.item():.15f}\n")


# Challenge 3: The Device-Aware Calculator
print("--- Challenge 3: The Device-Aware Calculator ---")
a = torch.randn(1000, 1000)
b = torch.randn(1000, 1000)

# We already have our best `device` from the spell cast earlier!
a_on_device = a.to(device)
b_on_device = b.to(device)

# Perform the operation
result = a_on_device @ b_on_device
print(f"Matrix multiplication performed on: {result.device}")
```



------



## Professor Torchenstein's Outro

Mwahahaha! Do you feel it? The hum of raw computational power at your fingertips? You have transcended the mundane world of default settings and seized control of the very essence of your tensors. You are no longer a mere summoner; you are an **alchemist** and a **dimensional traveler**!

You have learned to choose your weapons—the precise `dtype` for the task at hand and the mightiest `device` for your computations. This knowledge is the bedrock upon which all great neural architectures are built.

But do not rest easy! Our journey has just begun. The tensors are humming, eager for the next lesson where we shall unleash their raw mathematical power with **Elemental Tensor Alchemy**!

Until then, keep your learning rates high and your devices hotter! The future... is computational!

<video controls width="100%"  src="/assets/images/torchenstein_maniacal_laugh_close_up.mp4" title="Professor Torchenstein's maniacal laugh">
  Your browser does not support the video tag. Please update your browser to view this content.
</video>

