In [7]:
import torch
from torch.utils.data import Dataset, DataLoader

In [2]:
class CustomDataset(Dataset):
  def __init__(self, text, labels):
    self.labels = labels
    self.data = text

  def __len__(self):
    return len(self.labels)
  
  def __getitem__(self, idx):
    label = self.labels[idx]
    text = self.data[idx]
    sample = {"Text" : text, "Class" : label}
    return sample

In [6]:
text = ['Happy', 'Amazing', 'Sad', 'Unhappy', 'Glum']
labels = ['P','P','N','N','N']
MyDataset = CustomDataset(text, labels)

In [14]:
MyDataLoader = DataLoader(MyDataset, batch_size=2, shuffle=True)
next(iter(MyDataLoader))

{'Class': ['N', 'N'], 'Text': ['Sad', 'Unhappy']}

In [16]:
for dataset in MyDataLoader :
  print(dataset)

{'Text': ['Amazing', 'Unhappy'], 'Class': ['P', 'N']}
{'Text': ['Glum', 'Happy'], 'Class': ['N', 'P']}
{'Text': ['Sad'], 'Class': ['N']}


In [29]:
from torchvision.datasets import VisionDataset
from typing import Any, Callable, Dict, List, Optional, Tuple
import os

from tqdm import tqdm
import sys
from pathlib import Path
import requests

from skimage import io, transform
import matplotlib.pyplot as plt
import tarfile


In [50]:
class NotMNIST(VisionDataset):
  resource_url = 'http://yaroslavvb.com/upload/notMNIST/notMNIST_large.tar.gz'

  def __init__(
      self,
      root : str,
      train : bool = True,
      transfrom : Optional[Callable] = None,
      target_transform : Optional[Callable] = None,
      download : bool = False
  ) -> None :
    super(NotMNIST, self).__init__(root, transform=transform,
                                    target_transform=target_transform)
    if not self._check_exists() or download:
      self.download()

  def __len__(self):
    return len(self.data)
  
  def __getitem__(self, index):
    image_name = self.data[index]
    image = io.imread(image_name)
    label = self.targets[index]
    if self.transform :
      image = self.transform(image)
    return image, label
      
        
  def _load_data(self):
    filepath = self.image_folder
    data = []
    targets = []

    for target in os.listdir(filepath):
      filenames = [os.path.abspath(
          os.path.join(filepath, target, x)) for x in os.listdir(
              os.path.join(filepath, target))]
      target.extend([target] * len(filenames))
      data.extend(filenames)

    return data, targets

    @prorperty
    def raw_folder(self) -> str :
      return os.path.join(self.root, self.__class__.__name__, 'raw')

    @property
    def image_folder(self) -> str :
      os.makedirs(self.raw_folder, exist_ok=True)
      fname = self.resuorce_url.split("/")[-1]
      chunk_size = 1024

    def download(self) -> None :
      os.makedirs(self.raw_folder, exist_ok=True)
      fname = self.resource_url.split("/")[-1]
      chunk_size = 1024

      filesize = int(requests.head(self.resource_url).headers["Content-Length"])

      with requests.get(self.resource_url, stream=True) as r, open(
          os.path.join(self.raw_folder, fname), 'wb') as f, tqdm(
              unit = 'B',
              unit_scale = True,
              unit_divisor = 1024,
              total = filesize,
              file = sys.stdout,
              desc=fname
          ) as progress :
              for chunk in r.iter_content(chunk_size = chunk_size):
                datasize = f.write(chunk)
                progress.update(datasize)
            
      self._extract_file(os.path.join(self.raw_folder, fname), target_path=self.root)

    def _extract_file(self, fname, target_path) -> None :
      if fname.endswith("tar.gz"):
        tag = "r:gz"
      elif fname.endswith("tar"):
        tag = "r:"
      tar = tarfile.open(fname, tag)
      tar.extractall(path=target_path)
      tar.close()

    def _check_exists(self) -> bool:
        return os.path.exists(self.raw_folder)

In [51]:
dataset = NotMNIST("data", download=True)

KeyError: ignored