<a href="https://colab.research.google.com/github/jinyingtld/python/blob/main/AI6126_tutorial_8.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install dependencies: (use cu101 because colab has CUDA 10.1)
!pip install -U torch==1.5.1+cu101 torchvision==0.6.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html

Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==1.5.1+cu101
  Downloading https://download.pytorch.org/whl/cu101/torch-1.5.1%2Bcu101-cp37-cp37m-linux_x86_64.whl (704.4 MB)
[K     |████████████████████████████████| 704.4 MB 1.2 kB/s 
[?25hCollecting torchvision==0.6.1+cu101
  Downloading https://download.pytorch.org/whl/cu101/torchvision-0.6.1%2Bcu101-cp37-cp37m-linux_x86_64.whl (6.6 MB)
[K     |████████████████████████████████| 6.6 MB 51.5 MB/s 
Installing collected packages: torch, torchvision
  Attempting uninstall: torch
    Found existing installation: torch 1.10.0+cu111
    Uninstalling torch-1.10.0+cu111:
      Successfully uninstalled torch-1.10.0+cu111
  Attempting uninstall: torchvision
    Found existing installation: torchvision 0.11.1+cu111
    Uninstalling torchvision-0.11.1+cu111:
      Successfully uninstalled torchvision-0.11.1+cu111
[31mERROR: pip's dependency resolver does not currently take into account all the packages that

In [6]:
# Define an MNIST-M dataset class,
# since MNIST-M isn't provided by torchvision.
# See https://github.com/liyxi/mnist-m
import os
import warnings

import torch
from PIL import Image
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import download_and_extract_archive 


class MNISTM(VisionDataset):
    """MNIST-M dataset (https://github.com/liyxi/mnist-m).

    The archives listed in ``resources`` are downloaded into
    ``<root>/<ClassName>/raw`` and extracted into
    ``<root>/<ClassName>/processed``. The processed file for the requested
    split is read with ``torch.load`` and unpacked into ``self.data`` and
    ``self.targets``.

    Args:
        root (str): Base directory that holds the dataset folders.
        train (bool): Load ``training_file`` when true, ``test_file`` otherwise.
        transform (callable, optional): Applied to the PIL image returned by
            ``__getitem__``.
        target_transform (callable, optional): Applied to the integer label
            returned by ``__getitem__``.
        download (bool): Download and extract the archives when the processed
            files are not already present.
        downlaod (bool): Deprecated, misspelled alias of ``download``; kept so
            that existing callers keep working.

    Raises:
        RuntimeError: If the processed files are missing (after the optional
            download step).
    """

    # (url, md5) pairs of the train and test archives.
    resources = [
        ('https://github.com/liyxi/mnist-m/releases/download/data/mnist_m_train.pt.tar.gz',
         '191ed53db9933bd85cc9700558847391'),
        ('https://github.com/liyxi/mnist-m/releases/download/data/mnist_m_test.pt.tar.gz',
         'e11cb4d7fff76d7ec588b1134907db59')
    ]

    # File names expected inside ``processed_folder``.
    training_file = "mnist_m_train.pt"
    test_file = "mnist_m_test.pt"
    # Human-readable class names; the list position is the class index.
    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False, downlaod=False):
        """Init MNIST-M dataset."""
        super(MNISTM, self).__init__(root, transform=transform, target_transform=target_transform)
        self.train = train

        # ``downlaod`` was the original (misspelled) keyword; honour it but
        # point callers to the correct spelling.
        if downlaod:
            warnings.warn("downlaod has been renamed download")

        if download or downlaod:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found."
                               " You can use download=True to download it")

        # Fix: the class attribute is ``training_file`` (``train_file`` never
        # existed and raised AttributeError for train=True).
        data_file = self.training_file if self.train else self.test_file
        data_path = os.path.join(self.processed_folder, data_file)

        # Report which processed file is being loaded.
        print(data_path)

        # NOTE(review): torch.load unpickles the file. The md5 check only runs
        # as part of download(); files already on disk are loaded unverified.
        self.data, self.targets = torch.load(data_path)

    def __getitem__(self, index):
        """Get image and target for the data loader.

        Args:
            index (int): Index of the sample.

        Returns:
            tuple: (image, target) where target is the index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # Return a PIL Image so this dataset is consistent with the other
        # torchvision datasets.
        # assumes each stored sample is an HxWx3 uint8 tensor -- TODO confirm
        img = Image.fromarray(img.squeeze().numpy(), mode="RGB")

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return size of dataset."""
        return len(self.data)

    @property
    def train_labels(self):
        """Deprecated alias of ``targets``; emits a warning."""
        warnings.warn("train_labels has been renamed targets")
        return self.targets

    @property
    def test_labels(self):
        """Deprecated alias of ``targets``; emits a warning."""
        warnings.warn("test_labels has been renamed targets")
        return self.targets

    @property
    def train_data(self):
        """Deprecated alias of ``data``; emits a warning."""
        warnings.warn("train_data has been renamed data")
        return self.data

    @property
    def test_data(self):
        """Deprecated alias of ``data``; emits a warning."""
        warnings.warn("test_data has been renamed data")
        return self.data

    @property
    def raw_folder(self):
        """Directory the downloaded archives are stored in."""
        return os.path.join(self.root, self.__class__.__name__, 'raw')

    @property
    def processed_folder(self):
        """Directory the extracted ``.pt`` files are stored in."""
        return os.path.join(self.root, self.__class__.__name__, 'processed')

    @property
    def class_to_idx(self):
        """Mapping from each class name in ``classes`` to its index."""
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self):
        """Return True only when both the train and the test file exist."""
        return (os.path.exists(os.path.join(self.processed_folder, self.training_file)) and
                os.path.exists(os.path.join(self.processed_folder, self.test_file)))

    def download(self):
        """Download the MNIST-M data unless the processed files already exist."""

        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)
        os.makedirs(self.processed_folder, exist_ok=True)

        # Download each archive into raw_folder (checked against its md5) and
        # extract it into processed_folder.
        for url, md5 in self.resources:
            filename = url.rpartition('/')[2]
            download_and_extract_archive(url, download_root=self.raw_folder,
                                         extract_root=self.processed_folder,
                                         filename=filename, md5=md5)
        print("Done!")

    def extra_repr(self):
        """Extra line describing which split this instance holds."""
        return "Split: {}".format("Train" if self.train is True else "Test")
    


In [7]:
# Import packages
import torch 
import torchvision
from torch import nn
from torch.autograd import Function
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms 
import matplotlib.pyplot as plt 
import numpy as np 
from tqdm import tqdm # show progress bar 

In [8]:
# Select the compute device (GPU when CUDA is usable, otherwise CPU) and report it
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"CUDA is available: {device.type == 'cuda'}")

CUDA is available: False


In [None]:
# Download datasets
# The tutorial uses two digit datasets:
#   MNIST:   hand-written digits (28x28)
#   MNIST-M: MNIST's background blended with random color patches

# tfm1: replicate the gray channel into 3 channels, convert to a tensor,
# then normalize every channel with the same mean / std.
tfm1 = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,) * 3, std=(0.3081,) * 3),
])

# tfm2: same tensor conversion and normalization, without the grayscale step
# (presumably meant for the already 3-channel MNIST-M images -- confirm at use site).
tfm2 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,) * 3, std=(0.3081,) * 3),
])

# Both MNIST splits live under ./data, are requested with download=True and use tfm1.
mnist_train, mnist_test = (
    datasets.MNIST(root="./data", train=is_train, download=True, transform=tfm1)
    for is_train in (True, False)
)

