# Feature Loss: Training an MNIST Autoencoder with a VGG16 Perceptual Loss

In [0]:
from fastai import *
from fastai.tabular import *
import pandas as pd
from torchsummary import summary
import torch
from torch import nn
import imageio
import torch
import glob
from fastai.vision import *
import os
from torch import nn
import torch.nn.functional as F
from torchvision.models import vgg16_bn

## Data and Imports

In [0]:
# Environment setup: on Colab, mount Google Drive and copy the project's helper
# modules into the working directory so they can be imported; otherwise import
# them from the local `resnet_autoencoder_training` package.
colab = True  # manual toggle; set to False when running outside Colab
if colab:
  from google.colab import drive
  drive.mount('/content/drive', force_remount = True)
  # IPython `%cp` magics (notebook-only syntax, not plain Python).
  # model_layers.py is copied but not imported in this cell -- presumably
  # baseline_model imports it; TODO confirm.
  %cp "/content/drive/My Drive/autoencoder-training/model_layers.py" .
  %cp "/content/drive/My Drive/autoencoder-training/baseline_model.py" .
  %cp "/content/drive/My Drive/autoencoder-training/featureLoss_function.py" .
  import baseline_model
  import featureLoss_function
else: 
  # Move up one directory so `resnet_autoencoder_training` is importable
  # (assumes the notebook lives one level below the repo root -- TODO confirm).
  os.chdir("../")
  # NOTE(review): image_path is not referenced by the cells shown; the MNIST
  # data is fetched with untar_data(URLs.MNIST) instead.
  image_path = os.getcwd() + "/data"
  from resnet_autoencoder_training import baseline_model
  from resnet_autoencoder_training import featureLoss_function

In [0]:
# Seed every RNG the pipeline uses so runs are reproducible. fastai v1 draws
# its augmentation parameters from Python's `random` module, so seeding only
# numpy and torch (as before) left the augmentations unseeded.
import random

SEED = 3333
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

size = 32  # images are resized to size x size
batchsize = 128

path = untar_data(URLs.MNIST)
# Augmentations. NOTE: p_affine=0 gives the affine transforms (max_rotate,
# max_zoom, max_warp) zero probability of being applied, so those settings are
# effectively disabled; only the flip and lighting (p=0.75) transforms are
# active. Confirm this is intended.
tfms = get_transforms(do_flip=True, flip_vert=True, max_rotate=10, max_zoom=1.1, max_lighting=0.2, max_warp=0.2, p_affine=0,
                      p_lighting=0.75)
# Autoencoder data: each image is its own target (identity label function).
src = (ImageImageList.from_folder(path)
       .split_by_folder("training", "testing")
       .label_from_func(lambda x: x))
# tfm_y=True applies the same augmentation to input and target. Only the
# input is normalized with ImageNet stats (do_y=False); targets stay unnormalized.
data = (src.transform(tfms, size=size, tfm_y=True)
        .databunch(bs=batchsize)
        .normalize(imagenet_stats, do_y=False))

## Model

In [0]:
# Frozen, ImageNet-pretrained VGG16-BN feature extractor for the feature loss.
# It is moved to fastai's default device rather than unconditionally calling
# .cuda(), so it sits on the same device as the DataBunch batches and the
# notebook still runs on a CPU-only machine. On a GPU machine the behavior
# is unchanged.
vgg_m = vgg16_bn(True).features.to(defaults.device).eval()

# The loss network is only used for feature extraction; never train it.
requires_grad(vgg_m, False)
# Indices of the layers immediately preceding each MaxPool2d, i.e. the output
# of each conv stage just before spatial downsampling.
blocks = [i - 1 for i, o in enumerate(children(vgg_m)) if isinstance(o, nn.MaxPool2d)]

base_loss = F.mse_loss

# Compare features taken from the first three of those layers with weights
# 30/20/10 (argument order presumably model, layer_ids, layer_weights,
# base_loss -- see featureLoss_function.FeatureLoss to confirm).
feat_loss = featureLoss_function.FeatureLoss(vgg_m, blocks[0:3], [30, 20, 10], base_loss)

In [0]:
# Wrap the baseline autoencoder in a fastai Learner optimised with the feature
# loss, reporting MSE and MAE against the target images as metrics.
autoencoder = baseline_model.autoencoder()
learn = Learner(
    data,
    autoencoder,
    metrics=[mean_squared_error, mean_absolute_error],
    loss_func=feat_loss,
)

In [0]:
# Initial one-cycle training run: 5 epochs at fastai's default max learning rate.
learn.fit_one_cycle(5)

In [0]:
# Add R^2 and explained variance to the reported metrics for subsequent fits.
learn.metrics = [mean_squared_error, mean_absolute_error, r2_score, explained_variance]

In [0]:
# Learning-rate range test on the partially trained model, to choose max_lr
# for the next training run.
learn.lr_find()

In [0]:
# Plot loss vs. learning rate recorded by lr_find, marking fastai's suggested LR.
learn.recorder.plot(suggestion=True)

In [0]:
# Second one-cycle run: 10 epochs, with max_lr presumably picked from the
# LR-finder plot.
learn.fit_one_cycle(cyc_len=10, max_lr=1e-3)

In [0]:
# Visual check: sample inputs, targets and reconstructions from the training set.
learn.show_results(ds_type=DatasetType.Train)

In [0]:
# Visual check: sample inputs, targets and reconstructions from the validation set.
learn.show_results(ds_type=DatasetType.Valid)