In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
%cd '/content/drive/MyDrive/Subjects/Digital Video Processing/Final Project'
%run 'config.ipynb'

/content/drive/MyDrive/Subjects/Digital Video Processing/Final Project


In [None]:
import tensorflow as tens
import tensorflow.keras.backend as back
import math
import random

auto = tens.data.experimental.AUTOTUNE

def decode(img, h, w):
    """Decode JPEG bytes into a float32 image reshaped to ``[h, w, -1]``.

    Args:
        img: Encoded JPEG contents, as accepted by ``decode_jpeg``.
        h: Height used for the final reshape.
        w: Width used for the final reshape.

    Returns:
        A float32 tensor holding the decoded 3-channel image with its
        values divided by 255 and its shape forced to ``[h, w, -1]``.
    """
    pixels = tens.image.decode_jpeg(img, channels=3)
    # Scale the decoded values by 1/255 after converting to float.
    scaled = tens.cast(pixels, tens.float32) / 255.0
    return tens.reshape(scaled, [h, w, -1])

def read_tfrecord(e):
    """Parse one serialized example into three decoded images and their size.

    The record is expected to carry three JPEG-encoded strings under the
    keys ``img1``, ``img2`` and ``img3`` plus int64 ``height`` and ``width``.

    Args:
        e: A single serialized example.

    Returns:
        Tuple ``(img1, img2, img3, h, w)`` where the images come from
        ``decode`` and ``h`` / ``w`` are the stored sizes cast to int32.
    """
    string_feature = tens.io.FixedLenFeature([], tens.string)
    int_feature = tens.io.FixedLenFeature([], tens.int64)
    schema = {
        "img1": string_feature,
        "img2": string_feature,
        "img3": string_feature,
        'height': int_feature,
        'width': int_feature,
    }
    parsed = tens.io.parse_single_example(e, schema)

    h = tens.cast(parsed['height'], tens.int32)
    w = tens.cast(parsed['width'], tens.int32)

    # Every image in the record is decoded with the same stored size.
    img1, img2, img3 = (
        decode(parsed[key], h, w) for key in ('img1', 'img2', 'img3')
    )
    return img1, img2, img3, h, w

def get_matrix(rot, shear, hz, wz, hs, ws):
    """Build a 3x3 index-transform matrix: rotation . shear . zoom . shift.

    Every matrix entry is concatenated along axis 0 and reshaped to 3x3, so
    each tensor argument must contribute exactly one float32 element.

    Args:
        rot: Rotation angle in degrees (one-element tensor).
        shear: Shear angle in degrees (one-element tensor).
        hz: Zoom divisor for the first axis.
        wz: Zoom divisor for the second axis.
        hs: Shift along the first axis (one-element tensor).
        ws: Shift along the second axis (one-element tensor).

    Returns:
        The 3x3 product ``(rotation . shear) . (zoom . shift)``.
    """
    def as_3x3(*cells):
        # Nine one-element tensors, given row by row, become a 3x3 matrix.
        return tens.reshape(tens.concat(list(cells), axis=0), [3, 3])

    # Degrees -> radians.
    rot_rad = math.pi * rot / 180.
    shear_rad = math.pi * shear / 180.

    cos_r, sin_r = tens.math.cos(rot_rad), tens.math.sin(rot_rad)
    cos_s, sin_s = tens.math.cos(shear_rad), tens.math.sin(shear_rad)

    one = tens.constant([1], dtype='float32')
    zero = tens.constant([0], dtype='float32')

    rotation = as_3x3(cos_r, sin_r, zero,
                      -sin_r, cos_r, zero,
                      zero, zero, one)
    shearing = as_3x3(one, sin_s, zero,
                      zero, cos_s, zero,
                      zero, zero, one)
    zooming = as_3x3(one / hz, zero, zero,
                     zero, one / wz, zero,
                     zero, zero, one)
    shifting = as_3x3(one, zero, hs,
                      zero, one, ws,
                      zero, zero, one)

    rot_shear = back.dot(rotation, shearing)
    zoom_shift = back.dot(zooming, shifting)
    return back.dot(rot_shear, zoom_shift)

def transform(img, seed, sign=0):
    """Apply a random rotation/shear/shift to a square image patch.

    The output pixel grid is mapped through a random affine matrix and the
    source image is sampled at the resulting (truncated, clipped) indices.

    Args:
        img: Image tensor indexed as ``[row, col, channel]`` with 3 channels;
            it must be at least ``PATCH_SIZE[0]`` pixels along both spatial
            axes because indices are clipped to ``[0, PATCH_SIZE[0] - 1]``.
        seed: Integer operation-level seed. Calls that share the same seed
            are meant to draw the same random parameters, which lets a
            caller apply one transform to several images. ``None`` leaves
            every draw unseeded.
        sign: Multiplier applied to both random shifts. ``0`` disables the
            shift; ``1`` and ``-1`` shift in opposite directions.

    Returns:
        A ``[dim, dim, 3]`` tensor, where ``dim = PATCH_SIZE[0]``.
    """
    dim = PATCH_SIZE[0]
    xdim = dim % 2

    # Random ops that share one op-level seed (with no global seed set)
    # produce the same sequence, so reusing `seed` for all four draws made
    # rotation, shear and both shifts scaled copies of a single sample.
    # Offsetting the seed per parameter decorrelates them, while two calls
    # with the same `seed` still agree parameter by parameter.
    if seed is None:
        rot_seed = shr_seed = hs_seed = ws_seed = None
    else:
        rot_seed, shr_seed, hs_seed, ws_seed = (seed + k for k in range(4))

    rot = 15. * tens.random.normal([1], dtype='float32', seed=rot_seed)
    shr = 5. * tens.random.normal([1], dtype='float32', seed=shr_seed)
    hz = 1.0  # no zoom along either axis
    wz = 1.0
    hs = 6. * tens.random.normal([1], dtype='float32', seed=hs_seed) * sign
    ws = 6. * tens.random.normal([1], dtype='float32', seed=ws_seed) * sign

    m = get_matrix(rot, shr, hz, wz, hs, ws)

    # Homogeneous coordinates of every output pixel, centred on the patch.
    x = tens.repeat(tens.range(dim // 2, -dim // 2, -1), dim)
    y = tens.tile(tens.range(-dim // 2, dim // 2), [dim])
    z = tens.ones([dim * dim], dtype='int32')
    idx = tens.stack([x, y, z])

    # Map the output grid through the matrix, truncate to integers and clip
    # so that the source indices built below stay inside the patch.
    idx2 = back.dot(m, tens.cast(idx, dtype='float32'))
    idx2 = back.cast(idx2, dtype='int32')
    idx2 = back.clip(idx2, -dim // 2 + xdim + 1, dim // 2)

    # Convert centred coordinates back to (row, col) array indices.
    idx3 = tens.stack([dim // 2 - idx2[0, ], dim // 2 - 1 + idx2[1, ]])

    d = tens.gather_nd(img, tens.transpose(idx3))

    return tens.reshape(d, [dim, dim, 3])

def augment_data(img1, img2, img3, x, y):
    """Transform and centre-crop an image triplet into an (input, target) pair.

    All three images are passed through ``transform`` with one shared seed;
    ``img1`` uses ``sign=1``, ``img2`` the default sign and ``img3``
    ``sign=-1``. Each result is then centre-cropped by ``CROP_SIZE_fraction``.

    Args:
        img1: First image of the triplet.
        img2: Second image of the triplet; returned as the target.
        img3: Third image of the triplet.
        x: Unused.
        y: Unused.

    Returns:
        Tuple ``(inputs, target)``: ``inputs`` is ``img1`` and ``img3``
        concatenated along the last axis, ``target`` is the processed ``img2``.
    """
    # One Python-level seed is drawn per call and handed to every transform.
    seed = random.randint(0, 1000)

    first = transform(img1, seed, sign=1)
    middle = transform(img2, seed)
    last = transform(img3, seed, sign=-1)

    def crop(image):
        return tens.image.central_crop(
            image, central_fraction=CROP_SIZE_fraction
        )

    first, last, middle = crop(first), crop(last), crop(middle)
    inputs = tens.concat([first, last], axis=-1)
    return inputs, middle

def load_data(files):
    """Open TFRecord files as a dataset of parsed examples.

    Args:
        files: File name(s) handed to ``TFRecordDataset``.

    Returns:
        A dataset whose elements are the tuples produced by
        ``read_tfrecord``. Deterministic ordering is switched off and both
        the file reads and the parsing use the autotuned parallelism value.
    """
    # Relax ordering guarantees on the dataset.
    options = tens.data.Options()
    options.experimental_deterministic = False

    records = tens.data.TFRecordDataset(
        files, num_parallel_reads=auto
    ).with_options(options)
    return records.map(read_tfrecord, num_parallel_calls=auto)

def get_training_data(files):
    """Build the endlessly repeating, batched training dataset.

    Pipeline order: load -> augment -> repeat -> shuffle (buffer of 256)
    -> batch by ``BATCH_SIZE`` -> prefetch.

    Args:
        files: File name(s) forwarded to ``load_data``.

    Returns:
        A dataset yielding batches of the pairs produced by ``augment_data``.
    """
    return (
        load_data(files)
        .map(augment_data)
        .repeat()
        .shuffle(256)
        .batch(BATCH_SIZE)
        .prefetch(auto)
    )