##  **1. Custom Dataset Class**

You define your dataset by subclassing `torch.utils.data.Dataset` and overriding `__len__()` and `__getitem__()`.

```python
import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data  # e.g., a NumPy array or tensor
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx]
        y = self.labels[idx]
        return x, y
```
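
As a quick usage check, the sketch below builds the dataset from toy tensors and reads it both by index and through a `DataLoader`. The tensor shapes, the names `features` and `labels`, and the batch size are assumptions chosen for illustration; the class is repeated only so the snippet runs on its own.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    """Same class as above, repeated so the snippet is self-contained."""
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# Toy data (made up): 10 samples with 3 features each, alternating labels.
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10) % 2

dataset = MyDataset(features, labels)
x0, y0 = dataset[0]                      # calls __getitem__(0)

# DataLoader uses __len__ and __getitem__ to assemble batches.
loader = DataLoader(dataset, batch_size=4, shuffle=False)
batch_x, batch_y = next(iter(loader))
print(len(dataset), tuple(x0.shape), tuple(batch_x.shape))
```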

---

##  **2. `torch.utils.data.random_split`** 

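`torch.utils.data.random_split(dataset, lengths, generator=None)` splits a dataset into non-overlapping new datasets of the given lengths. Integer lengths must sum to `len(dataset)`; recent PyTorch releases also accept fractions that sum to 1.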

---



#### **2.1. Only Index-Based – No Data Copy**

* `random_split` **does not copy** the underlying data.
* It **wraps the original dataset** and uses internally shuffled indices to simulate subsets.
* Extra memory is limited to the index lists, because each split is just a `Subset` view of the original dataset.

Example:

```python
from torch.utils.data import random_split

train_ds, val_ds = random_split(full_dataset, [8000, 2000])
# `train_ds` and `val_ds` are `Subset` objects
```
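
To check the "no copy" claim yourself, the following sketch uses a toy `TensorDataset` as a stand-in for `full_dataset` (an assumption; any map-style dataset would do) and inspects the objects returned by `random_split`.

```python
import torch
from torch.utils.data import Subset, TensorDataset, random_split

# Toy stand-in for `full_dataset`: 10,000 one-element samples.
full_dataset = TensorDataset(torch.arange(10_000).unsqueeze(1))

train_ds, val_ds = random_split(full_dataset, [8000, 2000])

# Both splits are Subset objects pointing at the very same dataset object.
shares_dataset = train_ds.dataset is full_dataset and val_ds.dataset is full_dataset

# The index lists are disjoint and together cover the whole dataset.
no_overlap = set(train_ds.indices).isdisjoint(val_ds.indices)
print(type(train_ds).__name__, shares_dataset, no_overlap)
```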

---

####  **2.2. Reproducibility with Generator**

To ensure reproducibility (same split every run), pass a seeded `torch.Generator`:

```python
import torch
generator = torch.Generator().manual_seed(42)

train_ds, val_ds = random_split(full_dataset, [8000, 2000], generator=generator)
```

If you don't pass a generator, `random_split` draws from PyTorch's global default generator, so the split changes from run to run unless you seed that generator first with `torch.manual_seed(...)`.
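
A small sketch of what reproducibility means in practice: two splits made with generators seeded identically select the same indices. The helper `split_indices` and the toy dataset are hypothetical names introduced only for this illustration.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Toy dataset with 100 samples (assumption for illustration).
full_dataset = TensorDataset(torch.arange(100))

def split_indices(seed):
    """Hypothetical helper: split with a fresh generator seeded with `seed`."""
    generator = torch.Generator().manual_seed(seed)
    train_ds, val_ds = random_split(full_dataset, [80, 20], generator=generator)
    return list(train_ds.indices), list(val_ds.indices)

first = split_indices(42)
second = split_indices(42)

# Same seed, same split.
same_seed_matches = first == second
print(same_seed_matches)
```

Note that a generator advances every time it is used, so create (or re-seed) it right before the split you want to reproduce.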

---

####  **2.3. How It Works Internally**

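Internally, `random_split`:

* draws a random permutation of all indices with `torch.randperm`, using the given generator or, if none is passed, the global default one,
* cuts the permutation into consecutive chunks of the specified sizes,
* creates `Subset(dataset, indices)` for each chunk.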
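
The steps above can be sketched in a few lines. This is a simplified illustration, not PyTorch's actual source: `simple_random_split` is a hypothetical name, and the sketch handles integer lengths only.

```python
import torch
from torch.utils.data import Subset, TensorDataset

def simple_random_split(dataset, lengths, generator=None):
    """Simplified sketch of the idea behind random_split (integer lengths only)."""
    if sum(lengths) != len(dataset):
        raise ValueError("lengths must sum to len(dataset)")
    # Step 1: shuffle all indices once.
    perm = torch.randperm(len(dataset), generator=generator).tolist()
    # Steps 2 and 3: cut into consecutive chunks and wrap each in a Subset.
    splits, start = [], 0
    for length in lengths:
        splits.append(Subset(dataset, perm[start:start + length]))
        start += length
    return splits

# Toy usage on a 10-sample dataset (made up for illustration).
toy = TensorDataset(torch.arange(10))
train_part, val_part = simple_random_split(
    toy, [7, 3], generator=torch.Generator().manual_seed(0)
)
print(len(train_part), len(val_part))  # 7 3
```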

---

####  **2.4. Common Pitfalls**

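* **Don't modify the original dataset in-place** after splitting. The splits only reference it, so every change shows up in all of them.
* Be careful with imbalanced class distributions: `random_split` does **not** preserve class ratios, so a rare class can end up under-represented in one of the splits.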
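
The first pitfall is easy to reproduce. The sketch below assumes a toy `TensorDataset` over a tensor of zeros (made up for illustration) and shows that an in-place change to the underlying data is visible through an existing split.

```python
import torch
from torch.utils.data import TensorDataset, random_split

data = torch.zeros(10, 1)              # toy data, all zeros
full_dataset = TensorDataset(data)

generator = torch.Generator().manual_seed(0)
train_ds, val_ds = random_split(full_dataset, [8, 2], generator=generator)

before = train_ds[0][0].item()         # 0.0

data += 1.0                            # in-place change AFTER the split

after = train_ds[0][0].item()          # 1.0: the split sees the modified data
print(before, after)
```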

---



```python
import torch
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import random_split

dataset = CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())

generator = torch.Generator().manual_seed(123)

train_set, val_set = random_split(dataset, [45000, 5000], generator=generator)

print(len(train_set), len(val_set))  # 45000 5000
```

## **3. `torch.utils.data.Subset`**

`Subset` creates a **view** of a dataset using a list of indices. It’s a wrapper that lets you work with just a portion of a dataset **without copying** the data.

```python
torch.utils.data.Subset(dataset, indices)
```
---



#### **3.1. No Data Copy**

* Like `random_split`, `Subset` does **not duplicate data**: it stores only the index list and a reference to the original dataset.
* This keeps it memory-efficient and cheap to create.

#### **3.2. How It Works**

* Internally, `Subset` defines `__getitem__` essentially like this (simplified):

  ```python
  def __getitem__(self, idx):
      return self.dataset[self.indices[idx]]
  ```
* So each item access fetches from the original dataset using the provided index mapping.
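
A tiny example of the index mapping, using a made-up five-element `TensorDataset`: position `i` in the subset returns item `indices[i]` of the original dataset.

```python
import torch
from torch.utils.data import Subset, TensorDataset

# Toy dataset whose items are 10, 11, 12, 13, 14.
dataset = TensorDataset(torch.tensor([10, 11, 12, 13, 14]))

# Pick items 4, 0 and 2, in that order.
subset = Subset(dataset, [4, 0, 2])

values = [subset[i][0].item() for i in range(len(subset))]
print(values)  # [14, 10, 12]
```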

---

#### **3.3. Use Cases**

**Manual Splits**

Useful if you want custom train/val/test splits:

```python
from torch.utils.data import Subset
indices = list(range(len(dataset)))
train_idx = indices[:8000]
val_idx = indices[8000:]

train_ds = Subset(dataset, train_idx)
val_ds = Subset(dataset, val_idx)
```

**Stratified Splits with scikit-learn**

You can use `StratifiedShuffleSplit` to compute label-stratified index arrays and then wrap each array in a `Subset`:

```python
from sklearn.model_selection import StratifiedShuffleSplit
from torch.utils.data import Subset

sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
targets = dataset.targets  # Or dataset.labels depending on the dataset

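# split() uses X only for its length; the labels in y drive the stratification.
# With n_splits=1 the generator yields exactly one (train, val) pair.
train_idx, val_idx = next(sss.split(X=targets, y=targets))

# Convert the NumPy index arrays to plain Python ints before wrapping.
train_ds = Subset(dataset, train_idx.tolist())
val_ds = Subset(dataset, val_idx.tolist())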
```

---

**Accessing Original Dataset**

You can still access the original dataset and indices:

```python
subset.dataset   # Original dataset
subset.indices   # List of indices used
```
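
One practical use of these two attributes is checking the class balance of a split. The sketch below assumes a toy `TensorDataset` whose second tensor holds the labels; for other datasets the labels may live in an attribute such as `targets` instead.

```python
from collections import Counter

import torch
from torch.utils.data import Subset, TensorDataset

# Toy dataset (made up): 6 samples, labels 0, 0, 0, 1, 1, 1.
features = torch.randn(6, 2)
targets = torch.tensor([0, 0, 0, 1, 1, 1])
dataset = TensorDataset(features, targets)

subset = Subset(dataset, [0, 1, 3])

# Look up the labels of the selected samples through the original dataset.
subset_targets = subset.dataset.tensors[1][subset.indices]
counts = Counter(subset_targets.tolist())
print(counts)  # Counter({0: 2, 1: 1})
```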

---

**Summary Comparison: `random_split` vs `Subset`**

| Feature         | `random_split`                     | `Subset`                             |
| --------------- | ---------------------------------- | ------------------------------------ |
| Purpose         | Randomly divide dataset into parts | Create a view on dataset via indices |
| Data Copy?      | ❌ No                               | ❌ No                                 |
| Reproducibility | ✅ With `torch.Generator`           | ✅ If indices are controlled          |
| Shuffling?      | ✅ Internally when splitting        | ❌ You provide indices                |
| Use Cases       | Simple, fast random split          | Custom, stratified, or fixed split   |

---

## **4. ImageFolder**


In PyTorch, `torchvision.datasets.ImageFolder` is a utility class for loading image datasets arranged in a specific directory structure. It automatically assigns labels based on subdirectory names, making it ideal for classification tasks.

---

**Directory Structure**

`ImageFolder` expects the dataset directory to be structured like this:

```
root/
    class1/
        img1.png
        img2.png
        ...
    class2/
        img3.png
        img4.png
        ...
```

* Each **subfolder** under `root` is treated as a class.
* All images inside a class folder are treated as samples of that class.

---

**How It Works**

```python
from torchvision import datasets, transforms

# Define optional transforms (resizing, normalization, etc.)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Load dataset
dataset = datasets.ImageFolder(root='path/to/root', transform=transform)
```

---

**Labels and Classes**

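* `dataset.classes`: list of class names, sorted alphabetically (e.g., `['cat', 'dog']`)
* `dataset.class_to_idx`: dict mapping class names to label indices, assigned in that sorted order (e.g., `{'cat': 0, 'dog': 1}`)
* Each sample is a tuple `(image, label)`. The image is a tensor only if the transform includes `ToTensor()`; without a transform it is a PIL image.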

You can access an image and its label like this:

```python
img, label = dataset[0]
```
---