In [1]:
import os

import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor

%matplotlib inline 

AttributeError: module 'torch' has no attribute 'QUInt4x2Storage'

## Loading the FashionMNIST dataset

In [None]:
def _fashion_mnist_split(is_train):
    """Return one FashionMNIST split with images converted by ``ToTensor``.

    Args:
        is_train: ``True`` for the training split, ``False`` for the test split.

    With ``download=True`` torchvision fetches the files into ``data/`` if
    they are not already there.
    """
    return datasets.FashionMNIST(
        root="data",
        train=is_train,
        download=True,
        transform=ToTensor(),
    )


training_data = _fashion_mnist_split(True)
test_data = _fashion_mnist_split(False)

## Iterate and visualize the dataset

In [None]:
# Create a dictionary for labelset
labelmap = {
    0: "T-shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat", 
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag", 
    9: "Ankle boot",
}

In [None]:
# Show a rows x cols grid of randomly chosen training images, each titled
# with its class name from `labelmap`.
cols, rows = 3, 3
figure, axes = plt.subplots(rows, cols, figsize=(8, 8))
for ax in axes.flat:  # row-major, same order as add_subplot(rows, cols, i)
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    training_img, training_lbl = training_data[sample_idx]
    ax.set_title(labelmap[training_lbl])
    ax.axis('off')
    # squeeze() drops singleton dimensions so imshow receives a 2-D image.
    ax.imshow(training_img.squeeze(), cmap='gray')
plt.show()

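## Creating a custom dataset

The dataset below reads an annotations CSV in which the first column is an image filename (relative to `img_dir`) and the second column is that image's label.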

In [None]:
import pandas as pd
from torchvision.io import read_image

In [None]:
class customdataset(Dataset):
    """Map-style dataset of image files listed in an annotations CSV.

    Each CSV row describes one sample: column 0 is the image filename
    (joined onto ``img_dir``) and column 1 is its label.

    Args:
        img_dir: Directory containing the image files.
        annotations_file: Anything ``pd.read_csv`` accepts (path or file-like).
        img_transform: Optional callable applied to each loaded image.
        tgt_transform: Optional callable applied to each label.
    """

    def __init__(self, img_dir, annotations_file, img_transform=None, tgt_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.img_transform = img_transform
        self.tgt_transform = tgt_transform

    def __len__(self):
        """Return the number of annotated samples (rows in the CSV)."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)`` after any transforms."""
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        ip_img = read_image(img_path)
        tgt_lbl = self.img_labels.iloc[idx, 1]
        # `is not None` rather than truthiness: some callables (e.g. an empty
        # container-like transform) can be falsy yet still valid to call.
        if self.img_transform is not None:
            ip_img = self.img_transform(ip_img)
        if self.tgt_transform is not None:
            tgt_lbl = self.tgt_transform(tgt_lbl)
        # Without a target transform the raw label is returned unchanged.
        return ip_img, tgt_lbl
        

## DataLoaders for iterating over the dataset in batches

In [None]:
BATCH_SIZE = 64

# Reshuffle the training samples every epoch; keep the test order fixed.
train_dl = DataLoader(training_data, shuffle=True, batch_size=BATCH_SIZE)
test_dl = DataLoader(test_data, shuffle=False, batch_size=BATCH_SIZE)

### Inspect one training batch

In [None]:
# Pull a single batch from the training loader, report the image and label
# tensor sizes, then show the batch's first image with its class name.
train_imgs, train_tgts = next(iter(train_dl))
print("Size of training batch {}".format(train_imgs.size()))
# These are the labels of the *training* batch, not a test batch.
print("Size of training label batch {}".format(train_tgts.size()))
img = train_imgs[0].squeeze()
lbl = train_tgts[0]
# Grayscale colormap and labelmap title, matching the sample-grid cell;
# .item() turns the 0-d label tensor into a plain int dict key.
plt.imshow(img, cmap="gray")
plt.title(labelmap[lbl.item()])
plt.axis("off")
plt.show()