# **Experiment with Training on a Pre-Trained ImageNet Model**

One interesting exercise will be to apply Stable Diffusion using a pretrained model. Jeremy demonstrated using ImageNet for this task. Recall that during the super-resolution segment of the course, there was a huge difference in performance between the pretrained model and the model trained from scratch, with the former producing clearly better outputs than the latter.

The same thinking can be applied here. However, we will need a pretrained latents model, i.e. one whose downsampling layers have been pretrained on latents. A full ImageNet model pretrained on latents should be up to the task.

We can grab the full ImageNet dataset from Kaggle, under the [ImageNet Object Localization Challenge](https://www.kaggle.com/c/imagenet-object-localization-challenge/overview).

In [None]:
import os

# Uncomment to pin the notebook to a single GPU:
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Limit OpenMP to one thread per process (presumably so that the many DataLoader
# worker processes used below don't oversubscribe the CPU).
os.environ.update(OMP_NUM_THREADS='1')

In [None]:
import pickle,gzip

from glob import glob
from torcheval.metrics import MulticlassAccuracy
from fastprogress import progress_bar
from diffusers import AutoencoderKL

from miniai.imports import *

In [None]:
# Compact tensor printing and smaller inline figures.
torch.set_printoptions(precision=5, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams.update({'figure.dpi': 70})

set_seed(42)
# Never use more than 8 worker processes (this default feeds DataLoader num_workers below).
fc.defaults.cpus = min(fc.defaults.cpus, 8)

We will download the images from Kaggle into the ImageNet Large Scale Visual Recognition Challenge (ILSVRC) folder.

In [None]:
# Root of the Kaggle ILSVRC download, and the folder holding the raw classification images.
path_data = Path('data')/'ILSVRC'
path = path_data/'Data'/'CLS-LOC'

# Where the per-image VAE latents are cached. Deliberately NOT created here:
# the latent-encoding step below decides whether to run by checking for `dest`,
# so creating it eagerly would make that step skip encoding entirely.
dest = path_data/'latents'

In [None]:
# Pretrained Stable Diffusion VAE (EMA fine-tune), moved to the GPU and frozen:
# it is only used to encode/decode, never trained.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
vae = vae.cuda().requires_grad_(False)

In [None]:
class ImageDS:
    """Dataset of the image files under `path` that match glob pattern `spec`.

    The list of matching files is cached as a gzip-compressed pickle in
    `path/'files.zpkl'`, so later constructions skip the slow recursive glob.
    NOTE: the cache is keyed on `path` only; delete it if `spec` changes. It is a
    self-written local pickle, so never point this at an untrusted directory.

    Each item is `(image, filename)`: the image is read as RGB, scaled to
    [0, 1], center-cropped to a square and resized to 256x256.
    """
    def __init__(self, path, spec):
        cache = path/'files.zpkl'
        if cache.exists():
            with gzip.open(cache) as f: self.files = pickle.load(f)
        else:
            # Sort so item indices (and the cache) don't depend on filesystem listing order.
            self.files = sorted(glob(str(path/spec), recursive=True))
            # Don't cache an empty result: a wrong path or data that isn't there yet
            # would otherwise be frozen in as a permanently empty dataset.
            if self.files:
                with gzip.open(cache, 'wb', compresslevel=1) as f: pickle.dump(self.files, f)

    def __len__(self): return len(self.files)

    def __getitem__(self, i):
        f = self.files[i]
        # read_image gives a (C, H, W) uint8 tensor; dividing by 255 converts it to float in [0, 1].
        im = read_image(f, mode=ImageReadMode.RGB) / 255
        # Crop the largest centered square, then resize it to 256x256.
        im = TF.resize(TF.center_crop(im, min(im.shape[1:])), 256)
        return im, f

In [None]:
# Every JPEG under the CLS-LOC folder, read in file-list order (batches of 64).
ds = ImageDS(path, spec='**/*.JPEG')
dl = DataLoader(ds, batch_size=64, shuffle=False, num_workers=fc.defaults.cpus)

In [None]:
# Encode one batch of images to check the VAE latents.
xb, yb = next(iter(dl))
xe = vae.encode(xb.to('cuda'))
# Take the mean of the latent distribution (no sampling), so the latents are deterministic.
xs = xe.latent_dist.mean
xs.shape

In [None]:
# Visualize the first three latent channels of 16 items, squashed into (0, 1) for display.
show_images(torch.sigmoid(xs[:16, :3] / 4), imsize=2)

In [None]:
# Round-trip check: decode only the 16 latents we actually display (not the whole
# batch of 64), which saves GPU memory and decode time.
xd = to_cpu(vae.decode(xs[:16]))
show_images(xd['sample'].clamp(0,1), imsize=2)

In [None]:
# Encode every image once with the VAE and cache its latent (the distribution mean)
# as dest/<image path relative to `path`, without suffix>.npy.
# Check for existing latents rather than just for the directory: `dest` may
# already exist as an empty directory, which would otherwise skip encoding forever.
# NOTE: an interrupted run leaves `dest` non-empty; delete it to re-encode.
if not dest.exists() or not any(dest.iterdir()):
    dest.mkdir(parents=True, exist_ok=True)
    for xb, yb in progress_bar(dl):
        eb = to_cpu(vae.encode(xb.cuda()).latent_dist.mean).numpy()
        for latent, fname in zip(eb, yb):
            out = dest/Path(fname).relative_to(path).with_suffix('')
            out.parent.mkdir(parents=True, exist_ok=True)
            np.save(out, latent)  # np.save appends the '.npy' extension

In [None]:
class NumpyDS(ImageDS):
    "Same file discovery/caching as `ImageDS`, but each item is `(np.load(file), file)`."
    def __getitem__(self, i):
        fname = self.files[i]
        return np.load(fname), fname

In [None]:
# Batch size used for the latent DataLoaders below.
bs: int = 128

In [None]:
# Cached-latent datasets for the train and validation splits.
tds, vds = (NumpyDS(dest/split, '**/*.npy') for split in ('train', 'val'))

In [None]:
# Estimate per-channel latent statistics (reducing over dims 0, 2, 3, i.e. keeping dim 1) from one training batch.
# Shuffle: the file list is typically grouped by class directory, so an unshuffled
# first batch would come from only one or two classes and give biased statistics.
tdl = DataLoader(tds, batch_size=bs, shuffle=True, num_workers=0)
xb,yb = next(iter(tdl))

xb.mean((0,2,3)), xb.std((0,2,3))

In [None]:
# Per-channel mean and std of the 4 latent channels, hard-coded
# (presumably copied from the batch statistics printed above).
xmean = tensor([5.37007, 2.65468, 0.44876, -2.39154])
xstd = tensor([3.99512, 4.44317, 3.21629, 3.10339])

In [None]:
class TfmDS:
    "Wrap dataset `ds` so item `i` is `(tfmx(x), tfmy(y))`; both transforms default to `fc.noop`."
    def __init__(self, ds, tfmx=fc.noop, tfmy=fc.noop):
        self.ds = ds
        self.tfmx = tfmx
        self.tfmy = tfmy

    def __len__(self): return len(self.ds)

    def __getitem__(self, i):
        raw_x, raw_y = self.ds[i]
        return self.tfmx(raw_x), self.tfmy(raw_y)

In [None]:
# Synset ids in class-index order: line i of the file is the synset for class i.
id2str = (path_data/'imagenet_lsvrc_2015_synsets.txt').read_text().splitlines()
# Reverse lookup: synset id -> class index.
str2id = dict(zip(id2str, range(len(id2str))))

In [None]:
# Train-time augmentation: pad by 2, randomly crop to 32x32 (presumably the latent
# size, i.e. 256px images / 8), then random erasing.
aug_tfms = nn.Sequential(T.Pad(padding=2), T.RandomCrop(size=32), RandErase())
# Per-channel normalization with the hard-coded latent statistics.
norm_tfm = T.Normalize(mean=xmean, std=xstd)

In [None]:
def tfmx(x, aug=False):
    "Normalize latent array `x`; if `aug`, also run `aug_tfms` on it (adding then stripping a leading batch axis)."
    normed = norm_tfm(tensor(x))
    if not aug: return normed
    return aug_tfms(normed[None])[0]

def tfmy(y):
    "Map a latent's file path to its class index, using its parent directory's name as the synset id."
    # NOTE(review): assumes val latents also live in per-synset subdirectories — confirm.
    synset = Path(y).parent.name
    return tensor(str2id[synset])

# Training items are augmented; validation items are only normalized.
tfm_tds = TfmDS(tds, tfmx=partial(tfmx, aug=True), tfmy=tfmy)
tfm_vds = TfmDS(vds, tfmx=tfmx, tfmy=tfmy)

In [None]:
def denorm(x):
    "Undo `norm_tfm`: scale by the per-channel std and add back the mean (broadcast over the last two dims)."
    return x * xstd[..., None, None] + xmean[..., None, None]

In [None]:
# Train/valid DataLoaders over the transformed latent datasets. Use the same
# worker-count default as the image DataLoader (fc.defaults.cpus, capped at 8)
# instead of hard-coding 8, which oversubscribes machines with fewer cores.
dls = DataLoaders(*get_dls(tfm_tds, tfm_vds, bs=bs, num_workers=fc.defaults.cpus))

In [None]:
# Each line of words.txt is `<synset id>\t<comma-separated names>`; split on the first tab only.
all_synsets = [o.split('\t', maxsplit=1) for o in (path_data/'words.txt').read_text().splitlines()]
# Keep only synsets that are our classes, labelled by their first listed name.
# Test membership against the dict `str2id` (O(1)) instead of scanning the list `id2str` per line.
synsets = {k:v.split(',', maxsplit=1)[0] for k,v in all_synsets if k in str2id}

In [None]:
xb, yb = next(iter(dls.train))
# Human-readable label per item: class index -> synset id -> first listed name.
titles = [synsets[id2str[cls_idx]] for cls_idx in yb.tolist()]
xb.mean(), xb.std()

In [None]:
# Un-normalize 9 training latents, decode them with the VAE, and show them with their class names.
xd = to_cpu(vae.decode(denorm(xb[:9]).to('cuda')))
show_images(xd['sample'].clamp(min=0, max=1), imsize=4, titles=titles[:9])