In [4]:
import shutil,os,torch,random,datasets,math
import fastcore.all as fc, numpy as np, matplotlib as mpl, matplotlib.pyplot as plt
import torchvision.transforms as T
import torchvision.transforms.functional as TF,torch.nn.functional as F

from torch.utils.data import DataLoader,default_collate
from pathlib import Path
from torch.nn import init
from fastcore.foundation import L
from torch import nn,tensor
from operator import itemgetter
from torcheval.metrics import MulticlassAccuracy
from functools import partial
from torch.optim import lr_scheduler
from torch import optim
from torchvision.io import read_image,ImageReadMode
from glob import glob

from miniai.datasets import *
from miniai.conv import *
from miniai.learner import *
from miniai.activations import *
from miniai.init import *
from miniai.sgd import *
from miniai.resnet import *
from miniai.augment import *
#from miniai.accel import *
from miniai.training import *

In [5]:
from fastprogress import progress_bar

In [6]:
# Display settings: compact, non-scientific tensor printing and smaller inline figures.
torch.set_printoptions(precision=5, linewidth=140, sci_mode=False)
mpl.rcParams['figure.dpi'] = 70
# Reproducibility, plus an upper bound of 8 on the CPU count fastcore reports.
set_seed(42)
fc.defaults.cpus = min(fc.defaults.cpus, 8)

In [7]:
# Local data root (created if missing) and the Tiny ImageNet directory inside it.
path_data = Path('data')
path = path_data.joinpath('tiny-imagenet-200')
path_data.mkdir(exist_ok=True)

In [None]:
url = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'
# Download and extract only when the dataset directory is not already present.
if not path.exists():
    path_zip = fc.urlsave(url, path_data)
    # Extract the file urlsave actually wrote, into the configured data dir, instead of
    # re-hard-coding 'data/...' (which silently breaks if path_data is changed).
    shutil.unpack_archive(path_zip, path_data)

In [11]:
bs: int = 512  # mini-batch size passed to get_dls below

In [12]:
class TinyDS:
    "Dataset of (image file path, label) pairs for every `*.JPEG` found recursively under `path`."
    def __init__(self, path):
        self.path = Path(path)
        # Use self.path (not the raw arg) so a plain str path also works.
        # glob's order is filesystem-dependent; sort so item indices are reproducible.
        self.files = sorted(glob(str(self.path/'**/*.JPEG'), recursive=True))
    
    def __len__(self): return len(self.files)
    def __getitem__(self, i):
        "Return (file path as str, name of the file's grandparent directory as str)."
        # glob yields str, so wrap in Path before walking parents. Return the label as a
        # plain str so it can be used as a key into str-keyed mappings such as str2id.
        fn = self.files[i]
        return fn, Path(fn).parent.parent.name

tds = TinyDS(path.joinpath('train'))  # training split

In [14]:
# Build a filename -> wnid lookup from the first two tab-separated fields of each line.
path_anno = path.joinpath('val', 'val_annotations.txt')
anno = {}
for line in path_anno.read_text().splitlines():
    fname, wnid = line.split('\t')[:2]
    anno[fname] = wnid
anno

{'val_0.JPEG': 'n03444034',
 'val_1.JPEG': 'n04067472',
 'val_2.JPEG': 'n04070727',
 'val_3.JPEG': 'n02808440',
 'val_4.JPEG': 'n02808440',
 'val_5.JPEG': 'n04399382',
 'val_6.JPEG': 'n04179913',
 'val_7.JPEG': 'n02823428',
 'val_8.JPEG': 'n04146614',
 'val_9.JPEG': 'n02226429',
 'val_10.JPEG': 'n04371430',
 'val_11.JPEG': 'n07753592',
 'val_12.JPEG': 'n02226429',
 'val_13.JPEG': 'n03770439',
 'val_14.JPEG': 'n02056570',
 'val_15.JPEG': 'n02906734',
 'val_16.JPEG': 'n02125311',
 'val_17.JPEG': 'n04486054',
 'val_18.JPEG': 'n04285008',
 'val_19.JPEG': 'n03763968',
 'val_20.JPEG': 'n03814639',
 'val_21.JPEG': 'n03837869',
 'val_22.JPEG': 'n01983481',
 'val_23.JPEG': 'n01629819',
 'val_24.JPEG': 'n04532670',
 'val_25.JPEG': 'n04074963',
 'val_26.JPEG': 'n04540053',
 'val_27.JPEG': 'n04371430',
 'val_28.JPEG': 'n02906734',
 'val_29.JPEG': 'n02094433',
 'val_30.JPEG': 'n03796401',
 'val_31.JPEG': 'n07614500',
 'val_32.JPEG': 'n03837869',
 'val_33.JPEG': 'n03937543',
 'val_34.JPEG': 'n065963

In [18]:
class TinyValDS(TinyDS):
    "Validation variant of TinyDS: the label is looked up in the global `anno` by file basename."
    def __getitem__(self, i):
        fname = self.files[i]
        label = anno[os.path.basename(fname)]
        return fname, label

In [20]:
vds = TinyValDS(path.joinpath('val'))  # validation split

In [23]:
class TfmDS:
    "Wrap dataset `ds` so each item's input goes through `tfmx` and its target through `tfmy`."
    def __init__(self, ds, tfmx=fc.noop, tfmy=fc.noop):
        self.ds = ds
        self.tfmx = tfmx
        self.tfmy = tfmy

    def __len__(self): return len(self.ds)

    def __getitem__(self, i):
        xi, yi = self.ds[i]
        return self.tfmx(xi), self.tfmy(yi)

In [25]:
# Index -> wnid list (one wnid per line of wnids.txt) and its inverse wnid -> index map.
id2str = path.joinpath('wnids.txt').read_text().splitlines()
str2id = dict(zip(id2str, range(len(id2str))))
str2id

{'n02124075': 0,
 'n04067472': 1,
 'n04540053': 2,
 'n04099969': 3,
 'n07749582': 4,
 'n01641577': 5,
 'n02802426': 6,
 'n09246464': 7,
 'n07920052': 8,
 'n03970156': 9,
 'n03891332': 10,
 'n02106662': 11,
 'n03201208': 12,
 'n02279972': 13,
 'n02132136': 14,
 'n04146614': 15,
 'n07873807': 16,
 'n02364673': 17,
 'n04507155': 18,
 'n03854065': 19,
 'n03838899': 20,
 'n03733131': 21,
 'n01443537': 22,
 'n07875152': 23,
 'n03544143': 24,
 'n09428293': 25,
 'n03085013': 26,
 'n02437312': 27,
 'n07614500': 28,
 'n03804744': 29,
 'n04265275': 30,
 'n02963159': 31,
 'n02486410': 32,
 'n01944390': 33,
 'n09256479': 34,
 'n02058221': 35,
 'n04275548': 36,
 'n02321529': 37,
 'n02769748': 38,
 'n02099712': 39,
 'n07695742': 40,
 'n02056570': 41,
 'n02281406': 42,
 'n01774750': 43,
 'n02509815': 44,
 'n03983396': 45,
 'n07753592': 46,
 'n04254777': 47,
 'n02233338': 48,
 'n04008634': 49,
 'n02823428': 50,
 'n02236044': 51,
 'n03393912': 52,
 'n07583066': 53,
 'n04074963': 54,
 'n01629819': 55,
 '

In [26]:
# Per-channel (R, G, B) normalisation statistics used by tfmx/denorm.
# NOTE(review): presumably precomputed on the training images; the computation is not shown here.
xmean = tensor([0.47565, 0.40303, 0.31555])
xstd = tensor([0.28858, 0.24402, 0.26615])

In [28]:
def tfmx(x):
    "Load image file `x` as RGB, scale to [0, 1], then normalise each channel with xmean/xstd."
    mean, std = xmean[:, None, None], xstd[:, None, None]  # broadcast over H and W
    img = read_image(x, mode=ImageReadMode.RGB) / 255
    return (img - mean) / std

def tfmy(y):
    "Map wnid string `y` to its integer class index (via str2id), as a tensor."
    idx = str2id[y]
    return tensor(idx)

# Apply the same input/target transforms to both the train and validation splits.
tfm_tds, tfm_vds = (TfmDS(split, tfmx, tfmy) for split in (tds, vds))

In [29]:
def denorm(x):
    "Invert the per-channel normalisation from tfmx and clamp the result to [0, 1] for display."
    std, mean = xstd[:, None, None], xmean[:, None, None]
    return (x*std + mean).clip(0, 1)

In [32]:
all_synsets = [o.split('\t') for o in (path/'words.txt').read_text().splitlines()]
# Keep only this dataset's wnids, mapped to the first comma-separated name.
# Test membership against str2id (a dict with exactly id2str's entries as keys, O(1)
# lookup) rather than scanning the id2str list once per words.txt row.
synsets = {k:v.split(',', maxsplit=1)[0] for k,v in all_synsets if k in str2id}
synsets

{'n01443537': 'goldfish',
 'n01629819': 'European fire salamander',
 'n01641577': 'bullfrog',
 'n01644900': 'tailed frog',
 'n01698640': 'American alligator',
 'n01742172': 'boa constrictor',
 'n01768244': 'trilobite',
 'n01770393': 'scorpion',
 'n01774384': 'black widow',
 'n01774750': 'tarantula',
 'n01784675': 'centipede',
 'n01855672': 'goose',
 'n01882714': 'koala',
 'n01910747': 'jellyfish',
 'n01917289': 'brain coral',
 'n01944390': 'snail',
 'n01945685': 'slug',
 'n01950731': 'sea slug',
 'n01983481': 'American lobster',
 'n01984695': 'spiny lobster',
 'n02002724': 'black stork',
 'n02056570': 'king penguin',
 'n02058221': 'albatross',
 'n02074367': 'dugong',
 'n02085620': 'Chihuahua',
 'n02094433': 'Yorkshire terrier',
 'n02099601': 'golden retriever',
 'n02099712': 'Labrador retriever',
 'n02106662': 'German shepherd',
 'n02113799': 'standard poodle',
 'n02123045': 'tabby',
 'n02123394': 'Persian cat',
 'n02124075': 'Egyptian cat',
 'n02125311': 'cougar',
 'n02129165': 'lion'

In [33]:
# Use the CPU count (already capped at 8 in the config cell) rather than a hard-coded 8,
# so machines with fewer cores don't spawn more DataLoader workers than they have CPUs.
dls = DataLoaders(*get_dls(tfm_tds, tfm_vds, bs=bs, num_workers=fc.defaults.cpus))