## Super resolution

In [None]:
import fastai
from fastai.vision import *
from fastai.callbacks import *
from fastai.utils.mem import *

from torchvision.models import vgg16_bn

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Select GPU index 2 as the current CUDA device (hard-coded; change it to match the machine in use).
torch.cuda.set_device(2)

### Sizes of images: 

- HR images: (3, 1004, 1344)
- LR images: (3, 500, 669)

In [None]:
# Root folder of the high-resolution training data.
path = Path('../../../../../SCRATCH2/marvande/data/train/HR/')

# Full-resolution source images.
path_hr = path.joinpath('HR_patches_train/jpg_images')

# Destination folders for the downsized copies (filled in by the resizing cell).
# An alternative LR source, 'HR_patches_resized/jpg_images', was tried previously.
path_lr = path.joinpath('small-96/train')
path_mr = path.joinpath('small-256/train')

# Fail early when the dataset is not where the notebook expects it.
assert path.exists(), f"need dataset @ {path}"
assert path_hr.exists()

In [None]:
# Index every image under the HR folder once; `il.items` feeds the resizing step.
il = ImageList.from_folder(path_hr)
# Display the list that was just built rather than scanning the folder a second time.
il

In [None]:
def resize_one(fn, i, path, size):
    """Write a downsized RGB JPEG copy of the image `fn` under `path`.

    The copy keeps `fn`'s location relative to `path_hr`, so the destination
    tree mirrors the HR tree.

    Args:
        fn: path of the source image; must be located under `path_hr`.
        i: not used in the body; kept so the signature stays unchanged
            (presumably the item index supplied by `parallel` - confirm).
        path: root folder of the destination tree.
        size: target size passed to `resize_to` with `use_min=True`.
    """
    dest = path/fn.relative_to(path_hr)
    dest.parent.mkdir(parents=True, exist_ok=True)
    # Use a context manager so the source file handle is released as soon as
    # the pixels have been read, instead of lingering until garbage collection.
    with PIL.Image.open(fn) as img:
        targ_sz = resize_to(img, size, use_min=True)
        resized = img.resize(targ_sz, resample=PIL.Image.BILINEAR).convert('RGB')
    resized.save(dest, quality=60)

In [None]:
# One-off preprocessing: build each downsized image set only when its folder is missing.
sets = [(path_lr, 96), (path_mr, 256)]
for p, size in sets:
    if p.exists():
        continue  # this set was already generated on an earlier run
    print(f"resizing to {size} into {p}")
    parallel(partial(resize_one, path=p, size=size), il.items)


Creates two sets of smaller images from the HR images:
- HR: (3, 1004, 1344)
- LR: (3,96,128)
- MR: (3,256, 342)

In [None]:
ImageImageList.from_folder(path_mr)

In [None]:
ImageImageList.from_folder(path_lr)

In [None]:
# Batch size and image size used when building the training data.
bs = 32
size = 128
# Architecture handed to the U-Net learner further down.
arch = models.resnet34

# Low-resolution inputs, with a reproducible random 10% held out for validation.
src = ImageImageList.from_folder(path_lr)
src = src.split_by_rand_pct(0.1, seed=42)
src

In [None]:
src.label_from_func(lambda x: path_hr / x.relative_to(path_lr))

In [None]:
# Augmentation pipeline with no affine transforms: none of the entries below
# is wrapped in TfmAffine, only crop, pixel, coord and lighting transforms.
def _rand_tfm(tfm, p=1.0, **kwargs):
    """Wrap `tfm` in a RandTransform with probability `p` and the given kwargs.

    Every transform here is created with an empty `resolved` dict,
    `do_run=True`, `is_random=True` and `use_on_y=True`.
    """
    return RandTransform(tfm=tfm,
                         kwargs=kwargs,
                         p=p,
                         resolved={},
                         do_run=True,
                         is_random=True,
                         use_on_y=True)


# Pair of transform lists passed to `.transform(...)`; presumably the first is
# used for training and the second for validation (fastai convention - confirm).
tfms = (
    [
        _rand_tfm(TfmCrop(crop_pad),
                  row_pct=(0, 1),
                  col_pct=(0, 1),
                  padding_mode='reflection'),
        _rand_tfm(TfmPixel(flip_lr), p=0.5),
        _rand_tfm(TfmCoord(symmetric_warp), p=0.75, magnitude=(-0.2, 0.2)),
        _rand_tfm(TfmLighting(brightness), p=0.75, change=(0.4, 0.6)),
        _rand_tfm(TfmLighting(contrast), p=0.75, scale=(0.8, 1.25)),
    ],
    [
        _rand_tfm(TfmCrop(crop_pad)),
    ],
)


# Change data to (3,128,128)

def get_data(bs, size):
    """Build the LR -> HR DataBunch for the given batch size and image size.

    Each low-res item in `src` is labelled with the HR file at the same
    relative location under `path_hr`. `tfms` is applied to inputs and targets
    alike (`tfm_y=True`), and both are normalized with `imagenet_stats`
    (`do_y=True`).

    Args:
        bs: batch size of the resulting DataBunch.
        size: image size passed to `.transform`.

    Returns:
        The DataBunch, with its `c` attribute set to 3.
    """
    def hr_counterpart(lr_file):
        # Same relative path, rooted at the HR folder instead of the LR one.
        return path_hr / lr_file.relative_to(path_lr)

    labelled = src.label_from_func(hr_counterpart)
    augmented = labelled.transform(tfms, size=size, tfm_y=True)
    bunch = augmented.databunch(bs=bs)
    data = bunch.normalize(imagenet_stats, do_y=True)
    data.c = 3
    return data

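Comments:
- Both the inputs and the targets are normalized with `imagenet_stats` here (`do_y=True`). Is that appropriate for this dataset, or should dataset-specific statistics be used instead?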


Create training data of the shape (3, 128, 128):

In [None]:
data = get_data(bs,size)
data

In [None]:
data.show_batch(ds_type=DatasetType.Valid, rows=2, figsize=(9,9))

## Feature loss

In [None]:
# Tensor data of the second element (index 1, the label side) of the first validation item.
t = data.valid_ds[0][1].data
# Stack it with itself to get a batch of two for trying out the Gram matrix below.
t = torch.stack([t,t])

In [None]:
def gram_matrix(x):
    """Return the scaled Gram matrix of each sample in a 4-D batch.

    `x` has size (n, c, h, w). The spatial dimensions of every sample are
    flattened, the inner product between every pair of channels is taken, and
    the result is divided by c*h*w.

    Returns:
        A tensor of size (n, c, c).
    """
    batch, channels, height, width = x.size()
    flat = x.view(batch, channels, -1)
    scale = channels * height * width
    return flat.matmul(flat.transpose(2, 1)) / scale

In [None]:
gram_matrix(t)

In [None]:
base_loss = F.l1_loss

In [None]:
# Feature (convolutional) part of a pretrained VGG16-BN, moved to the GPU and put in eval mode.
vgg_m = vgg16_bn(True).features.cuda().eval()
# The network is only used to extract features for the loss, so turn off gradients for it.
requires_grad(vgg_m, False)

In [None]:
# Index of the layer immediately preceding each MaxPool2d in the VGG feature stack.
blocks = [i-1 for i,o in enumerate(children(vgg_m)) if isinstance(o,nn.MaxPool2d)]
# Show the indices together with the layers they point to.
blocks, [vgg_m[i] for i in blocks]

In [None]:
class FeatureLoss(nn.Module):
    """Pixel loss plus feature and Gram-matrix losses taken from `m_feat`.

    The outputs of selected layers of `m_feat` are captured with hooks. For a
    prediction/target pair the loss is the sum of:
      * `base_loss` on the raw tensors ('pixel'),
      * `base_loss` on each hooked activation, times its layer weight
        ('feat_i'),
      * `base_loss` on the Gram matrices of each hooked activation, times the
        squared layer weight and 5e3 ('gram_i').

    Args:
        m_feat: indexable feature network whose layer outputs are compared.
        layer_ids: indices into `m_feat` of the layers to hook.
        layer_wgts: one weight per hooked layer, in the same order.

    Raises:
        ValueError: if `layer_ids` and `layer_wgts` differ in length. Without
            this check `zip` would silently drop the extra layers and the
            names in `metrics` would no longer line up with their values.

    After every forward pass `self.metrics` maps each name in
    `self.metric_names` to the corresponding loss component.
    """
    def __init__(self, m_feat, layer_ids, layer_wgts):
        super().__init__()
        if len(layer_ids) != len(layer_wgts):
            raise ValueError(
                f"layer_ids and layer_wgts must have the same length, "
                f"got {len(layer_ids)} and {len(layer_wgts)}")
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        # detach=False keeps the stored activations attached to the graph so
        # gradients can flow back through them.
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))
              ] + [f'gram_{i}' for i in range(len(layer_ids))]

    def make_features(self, x, clone=False):
        """Run `x` through `m_feat` and return the hooked activations (cloned if `clone`)."""
        # The return value of the forward pass is not needed; the hooks store
        # the intermediate outputs as a side effect.
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]
    
    def forward(self, input, target):
        """Return the summed loss and refresh `feat_losses` and `metrics`."""
        # Target features are cloned because the next make_features call
        # refreshes what the hooks have stored.
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        self.feat_losses = [base_loss(input,target)]
        self.feat_losses += [base_loss(f_in, f_out)*w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out))*w**2 * 5e3
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)
    
    def __del__(self):
        # `hooks` is missing when __init__ raised before registering them;
        # guard so garbage collection does not raise AttributeError.
        hooks = getattr(self, 'hooks', None)
        if hooks is not None: hooks.remove()

In [None]:
# Compare activations at entries 2, 3 and 4 of `blocks`, weighted 5, 15 and 2 respectively.
feat_loss = FeatureLoss(vgg_m, blocks[2:5], [5,15,2])

## Train

In [None]:
# Weight decay used for the learner.
wd = 1e-3
# U-Net learner built on `arch` and trained with the feature loss; LossMetrics is attached
# as a callback (presumably to report the entries of `feat_loss.metrics` - confirm).
learn = unet_learner(data, arch, wd=wd, loss_func=feat_loss, callback_fns=LossMetrics,
                     blur=True, norm_type=NormType.Weight)
# Run a garbage-collection pass after building the learner.
gc.collect();

In [None]:
"""
# if problems with memory, restart run this, restart kernel:
torch.cuda.empty_cache()
gpu_mem_get_free_no_cache
learn.destroy()
"""

In [None]:
def memReport():
    """Print the type and size of every tensor the garbage collector currently tracks."""
    for tensor in filter(torch.is_tensor, gc.get_objects()):
        print(type(tensor), tensor.size())
    
def cpuStats():
    """Print the Python version, CPU load, system memory and this process's memory use."""
    # Imported here because no cell in the notebook imports psutil explicitly,
    # and os/sys would otherwise have to leak in through a star import.
    import os
    import sys

    import psutil

    print(sys.version)
    print(psutil.cpu_percent())
    print(psutil.virtual_memory())  # physical memory usage
    process = psutil.Process(os.getpid())
    # rss (the first field of memory_info) is in bytes; divide by 2**30 to get GiB.
    memoryUse = process.memory_info().rss / 2. ** 30
    print('memory GB:', memoryUse)

#cpuStats()
#memReport()

In [None]:
torch.cuda.device_count()
torch.cuda.current_stream()
torch.cuda.is_available()

In [None]:
learn.lr_find()
learn.recorder.plot()

In [None]:
lr = 1e-3

In [None]:
def do_fit(save_name, lrs=slice(lr), pct_start=0.9, epochs=10):
    """Run one one-cycle training stage on `learn`, save it and show sample results.

    Args:
        save_name: name under which the learner is saved after training.
        lrs: learning rate(s) passed to `fit_one_cycle`. The default
            `slice(lr)` is evaluated once, when this function is defined, so
            later changes to the global `lr` do not affect it.
        pct_start: `pct_start` argument forwarded to `fit_one_cycle`.
        epochs: number of epochs to train; defaults to 10, the value that was
            previously hard-coded.
    """
    learn.fit_one_cycle(epochs, lrs, pct_start=pct_start)
    learn.save(save_name)
    learn.show_results(rows=1, imgsize=5)

In [None]:
do_fit('1a', slice(lr*10))

In [None]:
learn.unfreeze()

In [None]:
do_fit('1b', slice(1e-5,lr))

In [None]:
data = get_data(12,size*2)

In [None]:
learn.data = data
learn.freeze()
gc.collect()

In [None]:
data

In [None]:
learn.load('1b');

In [None]:
do_fit('2a')

In [None]:
learn.unfreeze()

In [None]:
do_fit('2b', slice(1e-6,1e-4), pct_start=0.3)

## Test

In [None]:
learn = None
gc.collect();

In [None]:
256/320*1024

In [None]:
256/320*1600

In [None]:
# The largest test image that fits depends on how much GPU RAM is free:
# use the bigger size only when more than 8000 MB are available.
free = gpu_mem_get_free_no_cache()
size = (1280, 1600) if free > 8000 else (820, 1024)
print(f"using size={size}, have {free}MB of GPU RAM free")

In [None]:
learn = unet_learner(data, arch, loss_func=F.l1_loss, blur=True, norm_type=NormType.Weight)

In [None]:
# Medium-resolution inputs paired with their HR targets, one image per batch.
# The target is looked up by the item's path relative to `path_mr`, matching
# how `resize_one` lays out the resized copies and how `get_data` labels the
# LR set. Using only `x.name` would lose any sub-directory and point at the
# wrong (or a missing) HR file when the HR folder is nested.
data_mr = (ImageImageList.from_folder(path_mr).split_by_rand_pct(0.1, seed=42)
          .label_from_func(lambda x: path_hr / x.relative_to(path_mr))
          .transform(tfms, size=size, tfm_y=True)
          .databunch(bs=1).normalize(imagenet_stats, do_y=True))
data_mr.c = 3

In [None]:
data_mr

In [None]:
learn.load('2b');

In [None]:
learn.data = data_mr

In [None]:
fn = data_mr.valid_ds.x.items[0]; fn

In [None]:
img = open_image(fn); img.shape

In [None]:
p,img_hr,b = learn.predict(img)

In [None]:
show_image(img, figsize=(18,15), interpolation='nearest');

In [None]:
Image(img_hr).show(figsize=(18,15))