# Video Super Resolution


Data source: http://toflow.csail.mit.edu/

### Read Data

In [141]:
import mimetypes
import os
import random
from functools import partial
from pathlib import Path

try:
    from collections.abc import Iterable
except ImportError:  # legacy fallback; `collections.Iterable` was removed in Python 3.10
    from collections import Iterable

import matplotlib.pyplot as plt
import numpy
import PIL
import torch
from PIL import Image

In [66]:
def mylist(x):
    """Coerce *x* into something iterable.

    ``None`` becomes ``[]``; a non-iterable value is wrapped in a one-element
    list; anything already iterable (list, tuple, str, array, tensor, ...) is
    returned unchanged.
    """
    # `is None`, not `== None`: on a numpy array / torch tensor `==` is
    # element-wise and using the result in `if` raises "ambiguous truth value".
    if x is None:
        return []
    if not isinstance(x, Iterable):
        return [x]
    return x

In [1]:
def _get_files(p, fs, extensions=None):
    p = Path(p)
    res = [p/f for f in fs if not f.startswith('.')
           and ((not extensions) or ("." + f.split(".")[-1].lower()) in extensions)]
    return res

In [3]:
def get_files(path, extensions=None, recurse=False, include=None):
    """Collect the files under *path*, optionally filtered by extension.

    Args:
        path: directory to search.
        extensions: suffixes such as ``['.jpg']`` (matched case-insensitively),
            or a single suffix string. ``None`` (the default) keeps every file.
        recurse: when True, walk sub-directories with ``os.walk``.
        include: when recursing, the names of the top-level sub-directories
            to descend into; ``None`` descends into every non-hidden one.

    Returns:
        list of ``Path`` objects; hidden files/directories are skipped.
    """
    path = Path(path)
    # The default `None` used to crash here (set comprehension over None);
    # keep it as "no filter", which `_get_files` already understands.
    if extensions is not None:
        if isinstance(extensions, str):
            # A bare '.jpg' would otherwise be split into characters.
            extensions = [extensions]
        extensions = {e.lower() for e in extensions}
    if not recurse:
        names = [entry.name for entry in os.scandir(str(path)) if entry.is_file()]
        return _get_files(path, names, extensions)
    res = []
    for i, (dirpath, dirnames, filenames) in enumerate(os.walk(str(path))):
        # Rewriting `dirnames` in place controls which sub-directories os.walk visits next.
        if include is not None and i == 0:
            dirnames[:] = [o for o in dirnames if o in include]
        else:
            dirnames[:] = [o for o in dirnames if not o.startswith('.')]
        res += _get_files(dirpath, filenames, extensions)
    return res

In [5]:
path = Path('data')  # same as Path('./data'): pathlib drops the leading "./"

In [6]:
Path.ls = lambda self: list(self.iterdir())  # `p.ls()` -> list of the directory's entries

In [92]:
all_fns = get_files(path, extensions=['.jpg'])  # .jpg files directly inside `path` (no recursion)

In [25]:
import re

def atoi(text):
    """Return ``int(text)`` when *text* consists of decimal digits, else *text* unchanged.

    Used by ``natural_keys`` to turn digit chunks into numbers for natural sorting.
    """
    # isdecimal(), not isdigit(): isdigit() is also True for characters such
    # as '²' that int() rejects, which made this raise ValueError.
    return int(text) if text.isdecimal() else text

def natural_keys(text):
    """Sort key splitting ``str(text)`` into text and integer chunks.

    e.g. ``'frame10.jpg'`` -> ``['frame', 10, '.jpg']``, so ``frame2`` sorts
    before ``frame10``.
    """
    chunks = re.split(r'(\d+)', str(text))
    return list(map(atoi, chunks))

In [155]:
def nearbyframes(file,n_frames,max_num):
    """Open *file* and the frames that follow it as PIL images.

    The frame number is the last run of digits in the file's stem
    (``data/frame12.jpg`` -> 12). Frames ``k .. k+n_frames-1`` are opened,
    with indices clipped to ``[0, max_num]``.

    Args:
        file: path (str or Path) of the first frame.
        n_frames: number of consecutive frames to return.
        max_num: largest frame index that may be opened.

    Returns:
        list of ``n_frames`` images from ``PIL.Image.open``.

    Raises:
        ValueError: if the file name contains no digits.
    """
    p = Path(file)
    # Look only at the file stem: splitting the whole path (the old
    # natural_keys approach) picked up digits in directory names, and the
    # extension is excluded so suffixes like '.jp2' are not mistaken for it.
    match = re.search(r'(\d+)(\D*)$', p.stem)
    if match is None:
        raise ValueError(f"no frame number in file name {p.name!r}")
    head, tail = p.stem[:match.start(1)], match.group(2)
    start = int(match.group(1))
    # Clipping makes frames past max_num repeat the last valid frame.
    indices = numpy.clip(numpy.arange(start, start + n_frames), 0, max_num)
    return [PIL.Image.open(p.with_name(f"{head}{i}{tail}{p.suffix}")) for i in indices]
    

['data/frame1.jpg', 'data/frame1.jpg', 'data/frame1.jpg', 'data/frame1.jpg']

In [111]:
class ListContainer():
    """List wrapper with numpy-style indexing.

    Items are sorted with ``natural_keys`` when the container is built (so
    ``frame2`` precedes ``frame10``). Indexing accepts an int, a slice, a list
    of positions, or a boolean mask of the same length.
    """
    def __init__(self, items): self.items = sorted(mylist(items), key=natural_keys)

    def __getitem__(self, idx):
        if isinstance(idx, (int, slice)): return self.items[idx]
        # Check the length first so an empty index list returns [] instead of
        # raising IndexError on idx[0].
        if len(idx) > 0 and isinstance(idx[0], bool):
            assert len(idx) == len(self), "boolean mask length must match the container"
            return [o for m, o in zip(idx, self.items) if m]
        return [self.items[i] for i in idx]

    def __len__(self): return len(self.items)
    def __iter__(self): return iter(self.items)
    def __setitem__(self, i, o): self.items[i] = o
    def __delitem__(self, i): del self.items[i]

    def __repr__(self):
        # Show at most 10 items. The ellipsis appears only when something was
        # really cut off; the old `< 10` test also added it at exactly 10 items.
        shown = ', '.join(str(o) for o in self.items[:10])
        more = ', ...' if len(self) > 10 else ''
        return f"{self.__class__.__name__} ({len(self)} items)\n[ {shown}{more} ]"

In [112]:
def compose(x, funcs, *args, order_key='_order', **kwargs):
    """Thread *x* through every function in *funcs* and return the result.

    Functions run in ascending order of their ``order_key`` attribute
    (``_order`` by default; a missing attribute counts as 0). Ties keep their
    given order because ``sorted`` is stable. ``**kwargs`` goes to every call;
    extra positional ``*args`` are accepted but not used.
    """
    def rank(fn):
        return getattr(fn, order_key, 0)

    result = x
    for fn in sorted(mylist(funcs), key=rank):
        result = fn(result, **kwargs)
    return result

In [167]:
def combineframes(x, tfms):
    """Apply *tfms* to every frame in *x* and stack the results on a new leading axis.

    Every transformed frame must be a tensor of the same shape; the result
    has shape ``(len(x), *frame.shape)``. An empty *x* raises (torch.stack
    rejects an empty list).
    """
    frames = [compose(frame, tfms) for frame in x]
    # One torch.stack replaces the loop of `torch.cat(..., t[None])` calls,
    # which re-copied the growing tensor for every frame (quadratic copying).
    return torch.stack(frames)
    

In [262]:
class ItemList(ListContainer):
    """ListContainer that also keeps a root ``path`` and the ``tfms`` used on access.

    Indexing returns ``self.get(item)``. When that value is a non-string
    iterable (e.g. a list of frames), it is transformed with ``tfms`` and
    stacked by ``combineframes``.
    """
    def __init__(self, items, path='.', tfms=None):
        super().__init__(items)
        self.path, self.tfms = Path(path), tfms

    def __repr__(self): return super().__repr__() + "\nPath: " + str(self.path)

    def new(self, items, cls=None):
        """Build a list of the same class (or *cls*) sharing this path and tfms."""
        if cls is None: cls = self.__class__
        return cls(items, self.path, tfms=self.tfms)

    def get(self, i):
        """Hook for subclasses: map a stored item to the value to return (identity here)."""
        return i

    def _get(self, i):
        # Call get() only once: in subclasses it may open files from disk.
        item = self.get(i)
        # str/bytes are iterable but are plain values, not sequences of frames.
        if isinstance(item, Iterable) and not isinstance(item, (str, bytes)):
            return combineframes(item, self.tfms)
        return item

    def __getitem__(self, idx):
        res = super().__getitem__(idx)
        if isinstance(res, list): return [self._get(o) for o in res]
        return self._get(res)

class ImageList(ItemList):
    """ItemList of frame files; each item loads as a group of consecutive frames."""
    # Arguments passed to nearbyframes: how many frames per item, and the
    # largest frame index. Override on a subclass or instance for other clips.
    n_frames = 4
    max_frame = 900

    @classmethod
    def from_files(cls, path, extensions=None, recurse=False, include=None, **kwargs):
        """Build an ImageList from the files under *path* (``.jpg`` by default)."""
        if extensions is None: extensions = ['.jpg']
        return cls(get_files(path, extensions, recurse=recurse, include=include), path, **kwargs)

    def get(self, fn): return nearbyframes(fn, self.n_frames, self.max_frame)

In [263]:
def to_byte_tensor(item):
    """Convert an image to a ``uint8`` tensor of shape ``(channels, height, width)``.

    The raw bytes of ``item.tobytes()`` are viewed as ``(h, w, -1)``, with
    ``(w, h) = item.size``, so the channel count is inferred from the buffer
    length (e.g. 3 for an 8-bit RGB image).
    """
    # numpy.frombuffer + torch.from_numpy replaces torch.ByteStorage.from_buffer,
    # which is deprecated (TypedStorage) in recent PyTorch. .copy() gives torch
    # a writable array; frombuffer over an immutable bytes object is read-only.
    res = torch.from_numpy(numpy.frombuffer(item.tobytes(), dtype=numpy.uint8).copy())
    w, h = item.size
    return res.view(h, w, -1).permute(2, 0, 1)
to_byte_tensor._order=20

def to_float_tensor(item):
    """Return *item* as float32 scaled from [0, 255] to [0, 1].

    The division is out-of-place: ``.float()`` returns the same tensor when
    *item* is already float32, so an in-place ``div_`` would rescale the
    caller's tensor as a side effect.
    """
    return item.float() / 255.
to_float_tensor._order=30

def to_resize(size,item):
    """Resize the image *item* to *size* with bilinear resampling; *size* comes first so it can be bound with partial."""
    resample = PIL.Image.BILINEAR
    return item.resize(size, resample)
to_resize._order=10

In [264]:
# `tfms` was only defined in the following cell, so running the notebook
# top-to-bottom raised NameError here; define it before building the list.
tfms = [partial(to_resize,(128,128)), to_byte_tensor, to_float_tensor]
il = ImageList.from_files(path, tfms=tfms)

In [265]:
# compose() orders these by their `_order` attribute. A functools.partial has
# no `_order` of its own, so the resize sorts as 0 and still runs first.
tfms = [
    partial(to_resize, (128, 128)),
    to_byte_tensor,
    to_float_tensor,
]

In [266]:
il[0].size()  # torch.Size of the first item's stacked frames

torch.Size([4, 3, 128, 128])

In [267]:
def show_image(im, ax=None, figsize=(3,3)):
    """Draw the image tensor *im* with matplotlib, with axes hidden.

    *im* is permuted with ``(1, 2, 0)`` before plotting, i.e. it is expected
    channels-first. A new single-subplot figure of *figsize* is created
    when *ax* is not given.
    """
    ax = ax if ax is not None else plt.subplots(1, 1, figsize=figsize)[1]
    ax.axis('off')
    ax.imshow(im.permute(1, 2, 0))

In [268]:
def spliter(n,f):
    """Random split predicate: True with probability *n*; the item *f* is ignored."""
    return bool(random.random() < n)

In [269]:
def split_by_func(items, f):
    """Partition *items* by the result of ``f(item)``.

    Returns ``(falses, trues)``: the items whose result ``== False`` and
    ``== True`` respectively. Any other result (e.g. ``None``) leaves the
    item out of both lists.
    """
    flags = list(map(f, items))
    falses, trues = [], []
    for item, flag in zip(items, flags):
        # `==` on purpose (not `is`): 0/1 and numpy/torch booleans must match too.
        if flag == False:
            falses.append(item)
        if flag == True:
            trues.append(item)
    return falses, trues

In [270]:
class SplitData():
    """A training list and a validation list kept together.

    Attributes not found on the SplitData itself are looked up on ``train``.
    """
    def __init__(self, train, valid):
        self.train = train
        self.valid = valid

    def __getattr__(self, k):
        return getattr(self.train, k)

    # Needed so a pickled SplitData loads back: without it, unpickling looks up
    # `__setstate__` through __getattr__ before `train` exists, and __getattr__'s
    # own `self.train` lookup then recurses without end.
    def __setstate__(self, data):
        self.__dict__.update(data)

    @classmethod
    def split_by_func(cls, il, f):
        """Split ``il.items`` with the module-level ``split_by_func`` and wrap both halves via ``il.new``."""
        train_items, valid_items = split_by_func(il.items, f)
        return cls(il.new(train_items), il.new(valid_items))

    def __repr__(self):
        return f"{self.__class__.__name__}\nTrain: {self.train.__repr__()}\nValid:{self.valid.__repr__()}\n"

In [271]:
sd = SplitData.split_by_func(il, partial(spliter, 0.1))  # items with spliter True (p=0.1) go to valid
sd

SplitData
Train: ImageList790 items
[ data/frame0.jpg, data/frame1.jpg, data/frame2.jpg, data/frame3.jpg, data/frame4.jpg, data/frame5.jpg, data/frame6.jpg, data/frame7.jpg, data/frame8.jpg, data/frame9.jpg.....]
Path: data
Valid:ImageList111 items
[ data/frame22.jpg, data/frame28.jpg, data/frame51.jpg, data/frame52.jpg, data/frame54.jpg, data/frame68.jpg, data/frame73.jpg, data/frame74.jpg, data/frame97.jpg, data/frame100.jpg.....]
Path: data

In [288]:
### Add a label to each image (here each item's label is the item itself)
def parent_labeler(fn):
    """Identity labeler: the label of *fn* is *fn* itself (no parent directory is used)."""
    label = fn
    return label
def _label_by_func(ds, f, cls=ItemList):
    """Return a *cls* list of ``f(item)`` for every item of *ds*, rooted at ``ds.path`` (no tfms)."""
    labels = [f(item) for item in ds.items]
    return cls(labels, path=ds.path)

class LabeledData():
    """An input list ``x`` paired with a label list ``y``, each optionally processed."""
    def process(self, il, proc):
        """Run *proc* over ``il.items`` through ``compose`` and wrap the result with ``il.new``."""
        processed = compose(il.items, proc)
        return il.new(processed)

    def __init__(self, x, y, proc_x=None, proc_y=None):
        self.x, self.y = self.process(x, proc_x), self.process(y, proc_y)
        self.proc_x, self.proc_y = proc_x, proc_y

    def __repr__(self):
        return f"{self.__class__.__name__}\nx:{self.x.__repr__()}\n y:{self.y.__repr__()}\n"

    def __getitem__(self, idx): return self.x[idx], self.y[idx]

    def __len__(self): return len(self.x)

    def x_obj(self, idx): return self.obj(self.x, idx, self.proc_x)

    def y_obj(self, idx): return self.obj(self.y, idx, self.proc_y)

    def obj(self, items, idx, procs):
        """Fetch ``items[idx]`` and undo *procs* in reverse order.

        A single index (an int, or a 0-d LongTensor) goes through each
        processor's ``deproc1``; any other index uses ``deprocess``.
        """
        single = isinstance(idx, int) or (isinstance(idx, torch.LongTensor) and not idx.ndim)
        item = items[idx]
        for proc in reversed(mylist(procs)):
            if single:
                item = proc.deproc1(item)
            else:
                item = proc.deprocess(item)
        return item

    @classmethod
    def label_by_func(cls, il, f):
        """Pair *il* with labels built by applying *f* to each of its items."""
        labels = _label_by_func(il, f)
        return cls(il, labels)


In [289]:
# Label both halves of the split; parent_labeler makes each item its own label.
train, valid = (LabeledData.label_by_func(part, parent_labeler) for part in (sd.train, sd.valid))

In [290]:
ll = SplitData(train=train, valid=valid)

In [292]:
ll.train[0][0].size()  # x of the first training item: the frame plus the frames that follow it, stacked

torch.Size([4, 3, 128, 128])