📝 **Author:** Amirhossein Heydari - 📧 **Email:** <amirhosseinheydari78@gmail.com> - 📍 **Origin:** [mr-pylin/pytorch-workshop](https://github.com/mr-pylin/pytorch-workshop)

---


**Table of contents**<a id='toc0_'></a>    
- [Dependencies](#toc1_)    
- [Dataset](#toc2_)    
  - [Load Iris Dataset](#toc2_1_)    
- [Torch Dataset](#toc3_)    
  - [Built-in Datasets](#toc3_1_)    
    - [General-Purpose Dataset Wrappers](#toc3_1_1_)    
      - [TensorDataset](#toc3_1_1_1_)    
      - [Subset](#toc3_1_1_2_)    
      - [RandomSplit](#toc3_1_1_3_)    
      - [ConcatDataset](#toc3_1_1_4_)    
      - [ChainDataset](#toc3_1_1_5_)    
      - [ImageFolder](#toc3_1_1_6_)    
      - [DatasetFolder](#toc3_1_1_7_)    
    - [Predefined Datasets](#toc3_1_2_)    
      - [Torchvision Built-in Datasets](#toc3_1_2_1_)    
      - [Torchaudio Built-in Datasets](#toc3_1_2_2_)    
  - [Custom Datasets](#toc3_2_)    
    - [Base Class: Dataset](#toc3_2_1_)    
    - [Base Class: VisionDataset](#toc3_2_2_)    
    - [Base Class: VisionDataset's Derived Class](#toc3_2_3_)    
- [Torch DataLoader](#toc4_)    
  - [Strategies for updating weights](#toc4_1_)    
    - [Batch Gradient Descent](#toc4_1_1_)    
    - [Stochastic Gradient Descent](#toc4_1_2_)    
    - [Mini-Batch Gradient Descent](#toc4_1_3_)    
- [Torch DataParallel](#toc5_)    

<!-- vscode-jupyter-toc-config
	numbering=false
	anchor=true
	flat=false
	minLevel=1
	maxLevel=6
	/vscode-jupyter-toc-config -->
<!-- THIS CELL WILL BE REPLACED ON TOC UPDATE. DO NOT WRITE YOUR TEXT IN THIS CELL -->

# <a id='toc1_'></a>[Dependencies](#toc0_)


In [1]:
from math import ceil
from pathlib import Path
from typing import Callable

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import ChainDataset, ConcatDataset, DataLoader, Dataset, Subset, TensorDataset, random_split
from torchaudio.datasets import YESNO
from torchvision.datasets import MNIST, DatasetFolder, ImageFolder, VisionDataset
from torchvision.io import decode_image
from torchvision.transforms import v2

In [None]:
# update paths as needed based on your project structure
MNIST_DIR = r"../../datasets"
YESNO_DIR = r"../../datasets"

# <a id='toc2_'></a>[Dataset](#toc0_)

$
X = \begin{bmatrix}
        x_{1}^1 & x_{1}^2 & \cdots & x_{1}^n \\
        x_{2}^1 & x_{2}^2 & \cdots & x_{2}^n \\
        \vdots & \vdots & \ddots & \vdots \\
        x_{m}^1 & x_{m}^2 & \cdots & x_{m}^n \\
    \end{bmatrix}_{m \times n} \quad \text{(m: number of samples, n: number of features)}
$

$
Y = \begin{bmatrix}
        y_{1} \\
        y_{2} \\
        \vdots \\
        y_{m} \\
    \end{bmatrix}_{m \times 1} \quad \text{(m: number of samples)}
$


## <a id='toc2_1_'></a>[Load Iris Dataset](#toc0_)


In [None]:
iris_dataset_url = (
    r"https://raw.githubusercontent.com/mr-pylin/datasets/refs/heads/main/data/tabular-data/iris/dataset.csv"
)

# pandas data-frame
df = pd.read_csv(iris_dataset_url, encoding="utf-8")

# log
df.head()

In [None]:
classes = df["class"].unique()
class_to_idx = {l: i for i, l in enumerate(classes)}

# split dataset into features and labels
X, y = df.iloc[:, :4].values, df.iloc[:, 4].values

# convert categorical labels into indices
y = np.array([class_to_idx[l] for l in y])

# properties of the dataset
num_samples, num_features = X.shape
classes, samples_per_class = np.unique(y, return_counts=True)

# log
print(f"X.shape: {X.shape}")
print(f"X.dtype: {X.dtype}")
print(f"y.shape: {y.shape}")
print(f"y.dtype: {y.dtype}")
print("-" * 50)
print(f"classes          : {classes}")
print(f"samples per class: {samples_per_class}")

In [None]:
# convert numpy.ndarray to torch.Tensor
X = torch.from_numpy(X.astype(np.float32))
y = torch.from_numpy(y.astype(np.float32)).view(-1, 1)

# log
print(f"X.shape: {X.shape}")
print(f"X.dtype: {X.dtype}")
print(f"X.ndim : {X.ndim}\n")
print(f"y.shape: {y.shape}")
print(f"y.dtype: {y.dtype}")
print(f"y.ndim : {y.ndim}")

# <a id='toc3_'></a>[Torch Dataset](#toc0_)

- A **Dataset** is an abstraction that represents a collection of data samples.


## <a id='toc3_1_'></a>[Built-in Datasets](#toc0_)

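- PyTorch (`torch.utils.data`) and its domain libraries (`torchvision`, `torchaudio`) ship ready-made dataset classes that provide a standardized way to **access** and **manipulate** data, making it easier to work with different types of datasets.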

📝 **Docs & Tutorials** 📚:

- `torch.utils.data`: [pytorch.org/docs/stable/data.html](https://pytorch.org/docs/stable/data.html)
- Datasets & DataLoaders: [pytorch.org/tutorials/beginner/basics/data_tutorial.html](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html)
- ImageFolder: [pytorch.org/vision/stable/generated/torchvision.datasets.ImageFolder.html](https://pytorch.org/vision/stable/generated/torchvision.datasets.ImageFolder.html)
- DatasetFolder: [pytorch.org/vision/stable/generated/torchvision.datasets.DatasetFolder.html](https://pytorch.org/vision/stable/generated/torchvision.datasets.DatasetFolder.html)


### <a id='toc3_1_1_'></a>[General-Purpose Dataset Wrappers](#toc0_)

- They provide flexible tools to structure, manipulate, and process data.
- These classes are particularly useful for wrapping, splitting, or combining existing datasets without writing custom classes.

#### <a id='toc3_1_1_1_'></a>[TensorDataset](#toc0_)

- The `TensorDataset` class is part of the `torch.utils.data` module and allows you to **create** a dataset from one or more **tensors**.
- It assumes that the **first dimension** of each tensor corresponds to the **number of samples**, and it pairs the tensors together to form **samples**.


In [None]:
dataset_1 = TensorDataset(X, y)

# log
print(f"type(dataset_1) : {type(dataset_1)}")
print(f"len(dataset_1)  : {len(dataset_1)}")
print(f"dataset_1[0][0] : {dataset_1[0][0]}")
print(f"dataset_1[0][1] : {dataset_1[0][1]}")

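To make the pairing concrete, here is a minimal sketch on toy tensors (the names `features`, `labels`, and `ds` are illustrative). Indexing returns one entry from each tensor, and tensors whose first dimensions differ are rejected at construction time. In current PyTorch releases that check is an `assert`, so the sketch catches `Exception` broadly rather than relying on the exact exception type.

```python
import torch
from torch.utils.data import TensorDataset

features = torch.arange(12, dtype=torch.float32).reshape(6, 2)  # 6 samples, 2 features each
labels = torch.tensor([0, 1, 0, 1, 1, 0])                       # one label per sample

ds = TensorDataset(features, labels)
x0, y0 = ds[0]  # a sample is a tuple: (features[0], labels[0])
print(len(ds), x0.tolist(), int(y0))  # 6 [0.0, 1.0] 0

# every tensor must have the same size along dimension 0
try:
    TensorDataset(features, labels[:5])
    mismatch_rejected = False
except Exception:  # currently raised as an AssertionError ("Size mismatch between tensors")
    mismatch_rejected = True
print(mismatch_rejected)  # True
```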
#### <a id='toc3_1_1_2_'></a>[Subset](#toc0_)

- The `Subset` class is part of the `torch.utils.data` module and allows you to **create** a dataset from a selected **subset** of a larger dataset.
- It requires a list of **indices** that specify which samples to include, enabling **train-validation-test** splits or **selecting** specific data points.


In [7]:
# Iris-setosa
subset_1 = Subset(dataset_1, indices=range(50))

# Iris-versicolor
subset_2 = Subset(dataset_1, indices=range(50, 100))

# Iris-virginica
subset_3 = Subset(dataset_1, indices=range(100, 150))

In [None]:
# log
print(f"type(subset_1)                        : {type(subset_1)}")
print(f"len(subset_1)                         : {len(subset_1)}")
print(f"subset_1.indices                      : {subset_1.indices}")
print(f"subset_1.dataset                      : {subset_1.dataset}")
print(f"subset_1[0]                           : {subset_1[0]}")
print(f"subset_1.dataset[subset_1.indices[0]] : {subset_1.dataset[subset_1.indices[0]]}")

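Because `Subset` only stores indices, any index list works, not just contiguous ranges. Below is a minimal sketch of a per-class hold-out split on toy data; `per_class_split` is a hypothetical helper written for this example, not a PyTorch function.

```python
import torch
from torch.utils.data import Subset, TensorDataset

torch.manual_seed(0)
features = torch.randn(12, 3)
labels = torch.tensor([0] * 6 + [1] * 6)
dataset = TensorDataset(features, labels)


def per_class_split(labels: torch.Tensor, n_test: int) -> tuple[list[int], list[int]]:
    """Hypothetical helper: hold out the first `n_test` samples of every class."""
    train_idx, test_idx = [], []
    for c in labels.unique().tolist():
        idx = (labels == c).nonzero(as_tuple=True)[0].tolist()  # positions of class c
        test_idx += idx[:n_test]
        train_idx += idx[n_test:]
    return train_idx, test_idx


train_idx, test_idx = per_class_split(labels, n_test=2)
train_set, test_set = Subset(dataset, train_idx), Subset(dataset, test_idx)

print(len(train_set), len(test_set))  # 8 4
print(test_idx)                       # [0, 1, 6, 7]
```

`test_set[i]` is simply `dataset[test_idx[i]]`, so no data is copied.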
#### <a id='toc3_1_1_3_'></a>[RandomSplit](#toc0_)

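- The `random_split` function is part of the `torch.utils.data` module and allows you to **randomly** divide a dataset into **non-overlapping** subsets (returned as `Subset` objects).
- It takes a dataset and a list of **split sizes**, given either as absolute lengths that sum to `len(dataset)` or as fractions that sum to `1`.
- The split differs on every run unless you pass a seeded `generator` (e.g., `generator=torch.Generator().manual_seed(42)`), which makes it **reproducible**.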


In [None]:
train_ratio, val_ratio, test_ratio = (0.8, 0.1, 0.1)
train_set_1, val_set_1, test_set_1 = random_split(dataset_1, [train_ratio, val_ratio, test_ratio])

# log
print(f"type(train_set_1)   : {type(train_set_1)}")
print(f"len(train_set_1)    : {len(train_set_1)}")
print(f"train_set_1.indices : {train_set_1.indices}")
print(f"train_set_1.dataset : {train_set_1.dataset}")
print(f"train_set_1[0]      : {train_set_1[0]}")

In [None]:
train_length, val_length, test_length = (120, 15, 15)
train_set_2, val_set_2, test_set_2 = random_split(dataset_1, [train_length, val_length, test_length])

# log
print(f"type(test_set_2)   : {type(test_set_2)}")
print(f"len(test_set_2)    : {len(test_set_2)}")
print(f"test_set_2.indices : {test_set_2.indices}")
print(f"test_set_2.dataset : {test_set_2.dataset}")
print(f"test_set_2[0]      : {test_set_2[0]}")

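To see the effect of the `generator` argument, the sketch below splits a toy 10-sample dataset twice with the same seed and gets identical indices; `seeded_split` is a hypothetical helper for this example.

```python
import torch
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(torch.arange(10))


def seeded_split(seed: int) -> tuple[list[int], list[int]]:
    # a dedicated, seeded generator makes the permutation (and thus the split) repeatable
    generator = torch.Generator().manual_seed(seed)
    train_set, test_set = random_split(dataset, [0.8, 0.2], generator=generator)
    return train_set.indices, test_set.indices


first, second = seeded_split(42), seeded_split(42)
print(first == second)               # True: same seed, same split
print(len(first[0]), len(first[1]))  # 8 2
```

Using a separate `torch.Generator` instead of `torch.manual_seed` keeps the split reproducible without changing the global random state used elsewhere (e.g., weight initialization).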
#### <a id='toc3_1_1_4_'></a>[ConcatDataset](#toc0_)

- The `ConcatDataset` class is part of the `torch.utils.data` module and allows you to **merge** multiple datasets into a single dataset.
- The datasets should return samples with the same **structure** (e.g., `(features, label)` tuples), so samples from any of them can be batched together; this makes it useful for combining different datasets for training.


In [None]:
dataset_2 = ConcatDataset([train_set_1, val_set_1, test_set_1])

# log
print(f"type(dataset_2)            : {type(dataset_2)}")
print(f"len(dataset_2)             : {len(dataset_2)}")
print(f"dataset_2.cumulative_sizes : {dataset_2.cumulative_sizes}")
print(f"dataset_2.datasets         : {dataset_2.datasets}")
print(f"dataset_2[0]               : {dataset_2[0]}")

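`cumulative_sizes` is what lets `ConcatDataset` route a global index to the right child dataset. The sketch below reproduces that lookup with `bisect.bisect_right` on toy data; `locate` is a hypothetical helper that mirrors the idea and is not part of the PyTorch API.

```python
import bisect

import torch
from torch.utils.data import ConcatDataset, TensorDataset

part_a = TensorDataset(torch.arange(0, 3))    # 3 samples: 0, 1, 2
part_b = TensorDataset(torch.arange(10, 15))  # 5 samples: 10..14
combined = ConcatDataset([part_a, part_b])


def locate(cumulative_sizes: list[int], idx: int) -> tuple[int, int]:
    """Hypothetical helper: map a global index to (dataset index, local index)."""
    ds_idx = bisect.bisect_right(cumulative_sizes, idx)
    local_idx = idx if ds_idx == 0 else idx - cumulative_sizes[ds_idx - 1]
    return ds_idx, local_idx


print(combined.cumulative_sizes)              # [3, 8]
print(locate(combined.cumulative_sizes, 4))   # (1, 1): second dataset, its sample 1
print(combined[4])                            # (tensor(11),)
```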
#### <a id='toc3_1_1_5_'></a>[ChainDataset](#toc0_)

- The `ChainDataset` class is part of the `torch.utils.data` module and allows you to **iterate** sequentially over multiple datasets **without** merging them.
- It doesn't index elements but instead iterates over datasets in sequence (datasets must be iterables).
- This is efficient for streaming **large** datasets when merging them into memory is impractical.

✍️ **Note**:

- `ChainDataset` in PyTorch only supports datasets that are instances of `IterableDataset`!


In [None]:
# train_set_1, val_set_1, test_set_1 are map-style `Dataset`s (Subsets), not `IterableDataset`s:
# construction succeeds, but iterating over `dataset_3` would raise an AssertionError.
# this cell only demonstrates the constructor and its attributes.
dataset_3 = ChainDataset([train_set_1, val_set_1, test_set_1])

# log
print(f"type(dataset_3)    : {type(dataset_3)}")
print(f"dataset_3.datasets : {dataset_3.datasets}")

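For a version that can actually be iterated, the chained datasets must be `IterableDataset`s. A minimal sketch, assuming a hypothetical toy stream `RangeStream` that yields integers on demand instead of holding them in memory:

```python
import torch
from torch.utils.data import ChainDataset, IterableDataset


class RangeStream(IterableDataset):
    """Hypothetical stream that yields the integers in [start, stop) one at a time."""

    def __init__(self, start: int, stop: int):
        super().__init__()
        self.start, self.stop = start, stop

    def __iter__(self):
        for i in range(self.start, self.stop):
            yield torch.tensor(i)


chained = ChainDataset([RangeStream(0, 3), RangeStream(10, 12)])
values = [int(t) for t in chained]  # streams are consumed one after another
print(values)  # [0, 1, 2, 10, 11]
```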
#### <a id='toc3_1_1_6_'></a>[ImageFolder](#toc0_)

- The `ImageFolder` class is a subclass of `DatasetFolder` and part of the `torchvision.datasets` module; it loads **image** datasets organized in **folders**.
- The **folder** structure is expected to have **subfolders**, each representing a class. Each subfolder should contain images of the corresponding class.
- This is commonly used for loading datasets in image classification tasks.

✍️ **Note**:

- `ImageFolder` automatically labels images based on the **folder names**, making it convenient for supervised learning tasks.


In [None]:
try:
    image_dataset_1 = ImageFolder(root="/path/to/data", transform=None)
except FileNotFoundError as e:
    print(e)

#### <a id='toc3_1_1_7_'></a>[DatasetFolder](#toc0_)

- The `DatasetFolder` class is a subclass of `VisionDataset` and part of the `torchvision.datasets` module; it is a **generic** loader for samples of **any file type** (images, NumPy arrays, audio, ...).
- Like `ImageFolder`, it expects a **folder-based** layout (`root/<class_name>/<file>`), but you choose which files count as samples (via `extensions` or `is_valid_file`) and how they are read.

✍️ **Note**:

- `DatasetFolder` requires a **custom loader** function, which defines how each file is **read**, **processed**, and **returned**.
- `ImageFolder` is essentially a `DatasetFolder` preconfigured with an image loader and the common image extensions.


In [None]:
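from PIL import Image


def custom_loader(path: str) -> Image.Image:
    # open the file and convert it to a 3-channel RGB PIL image
    with open(path, "rb") as f:
        return Image.open(f).convert("RGB")


try:
    image_dataset_2 = DatasetFolder(root="/path/to/data", loader=custom_loader, extensions=(".jpg",), transform=None)
except FileNotFoundError as e:
    print(e)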

### <a id='toc3_1_2_'></a>[Predefined Datasets](#toc0_)

- Ready-made datasets for common benchmarks; many of them can download and prepare their data automatically (`download=True`).

📝 **Docs**:

- Torchvision Built-in Datasets: [pytorch.org/vision/stable/datasets.html](https://pytorch.org/vision/stable/datasets.html)
- Torchaudio Built-in Datasets: [pytorch.org/audio/stable/datasets.html](https://pytorch.org/audio/stable/datasets.html)


#### <a id='toc3_1_2_1_'></a>[Torchvision Built-in Datasets](#toc0_)

- Torchvision provides many built-in datasets in the `torchvision.datasets` module, as well as utility classes for building your own datasets.
- **Categories of built-in datasets**:
  - Image classification
  - Image detection or segmentation
  - Optical Flow
  - Stereo Matching
  - Image pairs
  - Image captioning
  - Video classification
  - Video prediction

ℹ️ **Learn more**:

- details about transforms: [**vision-transform.ipynb**](./vision-transform.ipynb)

In [None]:
mnist_dataset = MNIST(MNIST_DIR, train=True, transform=v2.ToImage(), download=False)

# log
print(f"len(mnist_dataset)        : {len(mnist_dataset)}")
print(f"mnist_dataset[0][0].shape : {mnist_dataset[0][0].shape}")
print(f"mnist_dataset[0][0].dtype : {mnist_dataset[0][0].dtype}")
print(f"classes                   : {mnist_dataset.classes}")

#### <a id='toc3_1_2_2_'></a>[Torchaudio Built-in Datasets](#toc0_)

- Torchaudio provides several built-in datasets in the `torchaudio.datasets` module.


In [None]:
yesno_dataset = YESNO(YESNO_DIR, download=False)

# log
print(f"len(yesno_dataset)        : {len(yesno_dataset)}")
print(f"yesno_dataset[0][0].shape : {yesno_dataset[0][0].shape}")
print(f"yesno_dataset[0][0].dtype : {yesno_dataset[0][0].dtype}")
print(f"sample_rate               : {yesno_dataset[0][1]}")

## <a id='toc3_2_'></a>[Custom Datasets](#toc0_)

- A custom dataset is useful when you need to handle non-standard data formats, apply specific preprocessing, or create dataset structures beyond what built-in datasets (e.g., `ImageFolder`, `MNIST`) provide.

📝 **Docs**:

- `torch.utils.data.Dataset`: [pytorch.org/docs/stable/data.html#torch.utils.data.Dataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset)
- `torch.utils.data.IterableDataset`: [pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset)
- `torchvision.datasets.VisionDataset`: [pytorch.org/vision/main/generated/torchvision.datasets.VisionDataset.html](https://pytorch.org/vision/main/generated/torchvision.datasets.VisionDataset.html)


### <a id='toc3_2_1_'></a>[Base Class: Dataset](#toc0_)

- The `Dataset` class is part of the `torch.utils.data` module and serves as the foundation for **loading data** in PyTorch.
- Use `Dataset` as the **parent** class and **implement** two key methods:
  - `__len__`: Returns the number of samples in the dataset.
  - `__getitem__`: Retrieves a **single** sample from the dataset at a given **index**.


In [17]:
class BaseDataset(Dataset):
    def __init__(self, data: torch.Tensor, labels: torch.Tensor):
        self.data = data
        self.labels = labels

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
        return self.data[index], self.labels[index]

In [None]:
dataset_4 = BaseDataset(X, y)

# log
print(f"type(dataset_4) : {type(dataset_4)}")
print(f"len(dataset_4)  : {len(dataset_4)}")
print(f"dataset_4[0][0] : {dataset_4[0][0]}")
print(f"dataset_4[0][1] : {dataset_4[0][1]}")

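Custom datasets are also the natural place for per-sample preprocessing. A minimal sketch, assuming a hypothetical `StandardizedDataset` that computes feature statistics once in `__init__` and applies them lazily in `__getitem__`; a `DataLoader` then batches it like any built-in dataset.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class StandardizedDataset(Dataset):
    """Hypothetical dataset that standardizes each feature column on access."""

    def __init__(self, data: torch.Tensor, labels: torch.Tensor):
        self.data, self.labels = data, labels
        self.mean = data.mean(dim=0)  # per-feature statistics, computed once
        self.std = data.std(dim=0)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
        x = (self.data[index] - self.mean) / self.std  # preprocessing happens per sample
        return x, self.labels[index]


data = torch.arange(20, dtype=torch.float32).reshape(10, 2)
labels = torch.arange(10)
loader = DataLoader(StandardizedDataset(data, labels), batch_size=4)

shapes = [tuple(x.shape) for x, _ in loader]
print(shapes)  # [(4, 2), (4, 2), (2, 2)]
```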
### <a id='toc3_2_2_'></a>[Base Class: VisionDataset](#toc0_)

- The `VisionDataset` class, part of the `torchvision.datasets` module, extends `Dataset` with a common interface for **vision datasets**: a `root` directory plus optional `transform` (applied to the input) and `target_transform` (applied to the label).
- Use `VisionDataset` as the **parent** class and **implement** two key methods:
  - `__len__`: Returns the number of samples in the dataset.
  - `__getitem__`: Retrieves a **single** sample from the dataset at a given **index**.

✍️ **Note**:

- `VisionDataset` does **not** read images itself: it stores `root` and the transforms and provides an informative `__repr__`, while the subclass supplies the loading logic (e.g., the `default_loader` function below).
- It is the base class of most `torchvision` datasets, such as `MNIST` and `DatasetFolder` (and therefore `ImageFolder`).

In [19]:
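def default_loader(path: Path) -> torch.Tensor:
    """Read an image file into a uint8 tensor of shape (3, H, W)."""
    try:
        image = decode_image(str(path))
        if image.shape[0] == 1:  # grayscale -> replicate to 3 channels
            image = image.repeat(3, 1, 1)
        elif image.shape[0] == 4:  # RGBA -> drop the alpha channel
            image = image[:3]
        return image
    except Exception as e:
        print(f"Error loading image {path}: {e}")
        # blank fallback image; uint8 to match the dtype returned by decode_image
        return torch.zeros((3, 224, 224), dtype=torch.uint8)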

In [None]:
class CustomImageDataset(VisionDataset):
    def __init__(
        self,
        root: str,
        loader: Callable,
        transform: v2.Compose | None = None,
        cache_images: bool = False,
    ):
        super().__init__(root, transform=transform)

        self.root = Path(root)
        self.loader = loader
        self.transform = transform
        self.cache_images = cache_images
        # sort class names so the class-to-index mapping is deterministic across machines
        self.classes = sorted(cls.name for cls in self.root.iterdir() if cls.is_dir())
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
        self.img_paths: list[tuple[Path, int]] = []

        # fetch and store (image_path, class_index) pairs
        for cls_name in self.classes:
            cls_dir = self.root / cls_name
            for img_path in cls_dir.glob("*.jpg"):
                self.img_paths.append((img_path, self.class_to_idx[cls_name]))

        # cache for storing images in memory
        self.image_cache = {}

    def __len__(self) -> int:
        return len(self.img_paths)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        img_path, label = self.img_paths[idx]

        # check if the image is already cached
        if self.cache_images and idx in self.image_cache:
            # fetch image from cache
            image = self.image_cache[idx]
        else:
            # load image
            image = self.loader(img_path)

            # cache the image if caching is enabled
            if self.cache_images:
                self.image_cache[idx] = image

        # apply transformations (if any)
        if self.transform:
            image = self.transform(image)

        # convert label to tensor of type torch.long
        label = torch.tensor(label, dtype=torch.long)

        return image, label

### <a id='toc3_2_3_'></a>[Base Class: VisionDataset's Derived Class](#toc0_)

- Subclassing an existing dataset (e.g., `MNIST` from `torchvision.datasets`) reuses its downloading and loading logic while letting you customize selected behavior.
- Use an existing dataset as the **parent** class and **override** key methods as needed:
  - `__getitem__`: Modifies the way a **single** sample is retrieved (e.g., changing labels, applying custom transformations).
  - Optionally, override `__len__` if the dataset size needs adjustment.
- This approach leverages built-in dataset loading while allowing for customization.


In [21]:
class CustomMNIST(MNIST):
    def __init__(self, root: str, train: bool = True, transform=None, target_transform=None, download: bool = True):
        super().__init__(root, train=train, transform=transform, target_transform=target_transform, download=download)

    def __getitem__(self, index: int):
        image, label = super().__getitem__(index)
        modified_label = label if label != 9 else 0
        return image, modified_label

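    # MNIST locates its files under `root/<class name>/raw`; without this override the subclass
    # would look in `root/CustomMNIST/raw`, so point it back to the original `root/MNIST/raw`
    @property
    def raw_folder(self) -> str:
        return str(Path(self.root) / MNIST.__name__ / "raw")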

In [None]:
dataset_5 = CustomMNIST(MNIST_DIR, train=True, transform=None, target_transform=None, download=False)

# log
print(f"type(dataset_5) : {type(dataset_5)}")
print(f"len(dataset_5)  : {len(dataset_5)}")

In [None]:
# plot
num_images = 10
_, axs = plt.subplots(nrows=1, ncols=num_images, figsize=(num_images * 1.5, 2), layout="compressed")
for i in range(num_images):
    image, label = dataset_5[i]  # PIL image (transform=None) and its modified label
    axs[i].imshow(image, cmap="gray")
    axs[i].set(title=label)
    axs[i].axis("off")
plt.show()

# <a id='toc4_'></a>[Torch DataLoader](#toc0_)

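- A DataLoader (`torch.utils.data.DataLoader`) wraps a dataset and feeds it to the model during training and evaluation. It handles:
  - **batching**: collating individual samples into batched tensors (`batch_size`)
  - **shuffling**: reordering samples each epoch (`shuffle=True`)
  - **parallel loading**: reading samples in worker processes (`num_workers`)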

📝 **Docs**:

- `torch.utils.data.DataLoader`: [pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)

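A small sketch of these behaviors on a toy 10-sample `TensorDataset` (the data is illustrative): how batches are sized, what `drop_last` does, how the default collate function stacks samples, and how a seeded `generator` makes shuffling reproducible.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

features = torch.arange(10, dtype=torch.float32).reshape(10, 1)
labels = torch.arange(10)
dataset = TensorDataset(features, labels)

# batching: 10 samples with batch_size=4 -> batches of 4, 4, and a final partial batch of 2
sizes = [len(xb) for xb, _ in DataLoader(dataset, batch_size=4)]
# drop_last=True discards the incomplete final batch
sizes_drop_last = [len(xb) for xb, _ in DataLoader(dataset, batch_size=4, drop_last=True)]

# the default collate function stacks per-sample tensors along a new first dimension
xb, yb = next(iter(DataLoader(dataset, batch_size=4)))

# shuffling: a seeded generator makes the shuffled order reproducible
loader = DataLoader(dataset, batch_size=4, shuffle=True, generator=torch.Generator().manual_seed(0))
seen = torch.cat([batch_labels for _, batch_labels in loader]).tolist()

print(sizes, sizes_drop_last)           # [4, 4, 2] [4, 4]
print(xb.shape, yb.tolist())            # torch.Size([4, 1]) [0, 1, 2, 3]
print(sorted(seen) == list(range(10)))  # True: every sample appears exactly once per epoch
```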

## <a id='toc4_1_'></a>[Strategies for updating weights](#toc0_)


### <a id='toc4_1_1_'></a>[Batch Gradient Descent](#toc0_)

- Uses the **entire** dataset to compute the **gradient** of the loss function and **update** the **weights**.
- **Pros**: The gradient is computed over all samples, so updates are smooth and convergence is stable.
- **Cons**: Can be very slow and computationally expensive for large datasets.

🌟 **Example**:

| #Epoch | batch size | #batch per epoch                    | #iteration per epoch                |
|:------:|:----------:|:-----------------------------------:|:-----------------------------------:|
| $ 2 $  | $ 150 $    | $ \lceil\frac{150}{150}\rceil = 1 $ | $ \lceil\frac{150}{150}\rceil = 1 $ |


In [None]:
epochs = 2
batch_size = len(dataset_1)
dataloader = DataLoader(dataset_1, batch_size=batch_size, shuffle=True, num_workers=2)

# log
for epoch in range(epochs):
    print(f"epoch {epoch+1:0{len(str(epochs))}}/{epochs}")
    for i, (x, y) in enumerate(dataloader):
        print(f"    iteration {i+1}/{ceil(len(dataset_1)/batch_size)}")
        print(f"        x.shape: {x.shape}")
        print(f"        y.shape: {y.shape}")
        print("    weights are updated.")
    print(f"model saw the entire dataset.")
    print("-" * 50)

### <a id='toc4_1_2_'></a>[Stochastic Gradient Descent](#toc0_)

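- The model **updates** the **weights** using only **one data point** at a time.
- **Pros**: Each update is cheap, and the noise in the gradient estimate can help the optimizer escape shallow local minima and saddle points.
- **Cons**: Updates are noisy, so the loss fluctuates and may not converge as smoothly as with batch gradient descent.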

🌟 **Example**:

| #Epoch | batch size | #batch per epoch                    | #iteration per epoch                |
|:------:|:----------:|:-----------------------------------:|:-----------------------------------:|
| $ 2 $  | $ 1 $      | $ \lceil\frac{150}{1}\rceil = 150 $ | $ \lceil\frac{150}{1}\rceil = 150 $ |


In [None]:
epochs = 2
batch_size = 1
dataloader = DataLoader(dataset_1, batch_size=batch_size, shuffle=True, num_workers=2)

# log
for epoch in range(epochs):
    print(f"epoch {epoch+1:0{len(str(epochs))}}/{epochs}")
    for i, (x, y) in enumerate(dataloader):
        if i % 50 == 0 or i == len(dataloader) - 1:
            print(f"    iteration {i+1}/{ceil(len(dataset_1)/batch_size)}")
            print(f"        x.shape: {x.shape}")
            print(f"        y.shape: {y.shape}")
            print("    weights are updated.")
    print(f"model saw the entire dataset.")
    print("-" * 50)

### <a id='toc4_1_3_'></a>[Mini-Batch Gradient Descent](#toc0_)

- The model updates its weights after processing a small batch of $b$ samples ($1 < b < m$) from the training dataset.
- This method balances SGD and Batch Gradient Descent: updates are more frequent than with full batches and less noisy than with single samples, and batched computation makes good use of vectorized hardware.

🌟 **Example**:

| #Epoch | batch size | #batch per epoch                   | #iteration per epoch               |
|:------:|:----------:|:----------------------------------:|:----------------------------------:|
| $ 2 $  | $ 32 $     | $ \lceil\frac{150}{32}\rceil = 5 $ | $ \lceil\frac{150}{32}\rceil = 5 $ |

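All three strategies apply the same update rule and differ only in the set $\mathcal{B}$ of sample indices used to estimate the gradient. A sketch in notation introduced here ($\theta$: model weights, $\eta$: learning rate, $\ell$: per-sample loss, $f$: the model, $x_i$ and $y_i$: the $i$-th sample and label, $m$: number of samples, $b$: mini-batch size):

$
\theta \leftarrow \theta - \eta \, \frac{1}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \nabla_{\theta}\, \ell\big(f(x_i; \theta),\, y_i\big),
\qquad
|\mathcal{B}| = \begin{cases} m & \text{batch gradient descent} \\ 1 & \text{stochastic gradient descent} \\ b & \text{mini-batch gradient descent} \end{cases}
$

One epoch therefore performs $\lceil m / |\mathcal{B}| \rceil$ updates, matching the tables above (e.g., $\lceil 150 / 32 \rceil = 5$, with a smaller final batch of $150 - 4 \cdot 32 = 22$ samples).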

In [None]:
epochs = 2
batch_size = 32
dataloader = DataLoader(dataset_1, batch_size=batch_size, shuffle=True, num_workers=2)

# log
for epoch in range(epochs):
    print(f"epoch {epoch+1:0{len(str(epochs))}}/{epochs}")
    for i, (x, y) in enumerate(dataloader):
        print(f"    iteration {i+1}/{ceil(len(dataset_1)/batch_size)}")
        print(f"        x.shape: {x.shape}")
        print(f"        y.shape: {y.shape}")
        print("    weights are updated.")
    print(f"model saw the entire dataset.")
    print("-" * 50)

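The cells above only print where an update *would* happen. Below is a minimal sketch of an actual mini-batch training loop, assuming a toy, noise-free linear-regression problem (the synthetic data, model, and hyperparameters are illustrative and not part of the original notebook). One optimizer step is taken per batch, so each epoch performs $\lceil 128 / 16 \rceil = 8$ updates.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# synthetic, noise-free data: y = 3x - 2
x = torch.randn(128, 1)
y = 3 * x - 2
loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

num_updates = 0
for epoch in range(50):
    for x_batch, y_batch in loader:
        optimizer.zero_grad()                    # clear gradients from the previous step
        loss = loss_fn(model(x_batch), y_batch)  # loss on this mini-batch only
        loss.backward()                          # gradient estimate from 16 samples
        optimizer.step()                         # one weight update per mini-batch
        num_updates += 1

print(num_updates)                             # 400 = 50 epochs * 8 batches
print(model.weight.item(), model.bias.item())  # close to 3 and -2
```

Setting `batch_size=len(x)` or `batch_size=1` in the same loop turns it into batch gradient descent or SGD respectively.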
# <a id='toc5_'></a>[Torch DataParallel](#toc0_)

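- DataParallel (`torch.nn.DataParallel`) enables data-level parallelism on a single machine by distributing each input batch across multiple GPUs and gathering the results:
  - Replicates the model onto every GPU on each forward pass.
  - Splits the input batch along the first dimension across GPUs.
  - Each GPU processes its chunk independently.
  - Outputs from all GPUs are gathered on the primary (default) GPU.
- The PyTorch docs recommend `DistributedDataParallel` over `DataParallel` for multi-GPU training, even on a single node.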

📝 **Docs**:

- `nn.DataParallel`: [pytorch.org/docs/stable/generated/torch.nn.DataParallel.html](https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html)

📚 **Tutorials**:

- Data Parallelism: [pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html](https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html)

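One practical detail of the wrapper: `nn.DataParallel` keeps the original network in its `.module` attribute, so the wrapped model's `state_dict` keys gain a `module.` prefix. A minimal sketch (the `nn.Linear(4, 3)` model is illustrative); on a machine without GPUs the wrapper simply calls the underlying module, so it also runs on CPU.

```python
import torch
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(4, 3).to(device)  # parameters must live on the primary device before wrapping
wrapped = nn.DataParallel(model)

print(list(wrapped.state_dict().keys()))         # ['module.weight', 'module.bias']
print(list(wrapped.module.state_dict().keys()))  # ['weight', 'bias']

output = wrapped(torch.randn(8, 4, device=device))
print(output.shape)  # torch.Size([8, 3])
```

When saving a checkpoint from a wrapped model, `wrapped.module.state_dict()` yields keys that load directly into an unwrapped model.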

In [27]:
epochs = 2
batch_size = 30
dataloader = DataLoader(dataset_1, batch_size=batch_size, shuffle=True, num_workers=2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [28]:
class CustomModel(nn.Module):

    def __init__(self, in_features, out_features):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, input):
        output = self.fc(input)
        print(f"  inside model -> input size: {input.size()} | output size: {output.size()}")
        return output


model = CustomModel(4, 3).to(device)

In [29]:
if torch.cuda.device_count() > 1:
    # [30, ...] -> [15, ...], [15, ...] on 2 GPUs
    # [30, ...] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
    # ...
    model = nn.DataParallel(model)

In [None]:
# log
print(f"number of GPUs: {torch.cuda.device_count()}")
for epoch in range(epochs):
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        y_pred = model(x)