In [25]:
import matplotlib as plt
import numpy as np
import argparse
import os
from tqdm.notebook import tqdm as tqdm
import jax
import jax.numpy as jnp
from jax.experimental.optimizers import adam
from jax.experimental.stax import *
from jax import random, jit
import numpy as np
import glob
from skimage.transform import resize
from livelossplot import PlotLosses

from utils.img_ops import *
from utils import *

from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, ColorJitter, Resize, ToTensor

## Data Loading

In [21]:
class ImmFaceDb(Dataset):
    def __init__(self, root_dir, keypoints=[-6], transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
        """
        self.root_dir = root_dir
        img_paths = glob.glob(root_dir + '/*.png')
        img_paths.sort()
        self.img_paths = np.array(img_paths)

        asf_paths = glob.glob(root_dir + '/*.asf')
        asf_paths.sort()
        self.asf_paths = np.array(asf_paths)

        self.keypoints = keypoints
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)
    
    def __getitem__(self, idx):
        img_paths = self.img_paths[idx]
        asf_paths = self.asf_paths[idx]
        
        imgs = np.array([read(i, color=False) for i in img_paths])
        imgs = imgs - 0.5
        asf_data = [self.read_asf(f, self.keypoints) for f in asf_paths]
        asf_data = np.row_stack([np.column_stack((x[0], x[1])) for x in asf_data])
        sample = {'x': imgs, 'y': asf_data}

        if self.transform:
            sample = self.transform(sample)

        return sample

    def read_asf(self, file, keypoints=None):
        """Reads x, y points from asf file

        Args:
            file (str): path to asf file

        Returns:
            np.ndarray, np.ndarray: list of x points, list of y points
        """
        if keypoints:
            data = np.genfromtxt(file, skip_header=16, skip_footer=1, usecols=(2, 3))[keypoints, :]
        else:
            data = np.genfromtxt(file, skip_header=16, skip_footer=1, usecols=(2, 3))[:, :]
        return data[:, 0], data[:, 1]


In [23]:
class ColorJitterX(object):
    def __call__(self, sample):
        image, points = sample['x'], sample['y']
        plt.imshow(image)
        plt.show()
        brightness = np.random.random() * 2 - 1
        contrast = np.random.random() * 2 - 1
        image = image * contrast + brightness
        plt.imshow(image)
        plt.show()
        return {'x': image, 'y': points}

In [24]:
batch_size = 8
transforms = Compose([
    Resize((240, 180)),
    ColorJitterX(),
])
keypoints_dataset = ImmFaceDb('imm_face_db/', keypoints=None, transform=transforms)
data_loader = DataLoader(keypoints_dataset, batch_size=batch_size, shuffle=False)
print(list(data_loader))
for i_batch, sample_batched in enumerate(data_loader):
    print('hello')
    plt.imshow(sample_batched['x'][0])
    plt.show()

[]
[]


In [None]:
# Load Data
val_split = (32, 8)
train_imgs, train_points, val_imgs, val_points = load_data(
    val_split=val_split, keypoints=[-6]
)
img = train_imgs[0]
h, w, c = img.shape[0], img.shape[1], img.shape[2]

In [None]:
model_init_fn, model = serial(
    Conv(32, (3, 3), padding="SAME"),
    MaxPool((2, 2)),
    Relu,
    Conv(32, (3, 3), padding="SAME"),
    MaxPool((2, 2)),
    Relu,
    Conv(32, (3, 3), padding="SAME"),
    MaxPool((2, 2)),
    Relu,
    Flatten, 
    Dense(256),
    Relu,
    Dense(2)
)

In [None]:
# Loss -- MSE
@jit
def loss_fn(params, imgs, gt):
    pred = model(params, imgs)
    return np.mean((pred - gt) ** 2)

@jit
def update(step, opt_state, imgs, gt):
    value, grads = jax.value_and_grad(loss_fn)(opt.params_fn(opt_state), imgs, gt)
    opt_state = opt.update_fn(step, grads, opt_state)
    return value, opt_state

In [None]:
# Initialize Model
batch_size = 16
rng = random.PRNGKey(32)
input_shape, params = model_init_fn(rng, (batch_size, h, w, c))

# Optimizer
lr = 1e-3
opt = adam(lr)
opt_state = opt.init_fn(params)

# Create Plots
plt_groups = {'loss':[]}
plotlosses_model = PlotLosses(groups=plt_groups)
plt_groups['loss'].append('nose_pred_train')
plt_groups['loss'].append('nose_pred_val')

# Training Loop
epochs = 5
steps = int(len(train_points) // batch_size)
val_ratio = int(len(train_points) // len(val_points))

train_loss = np.ndarray(steps * epochs)
val_loss = np.ndarray(steps * epochs // val_ratio)

iters = 0
save_params = params
min_val = 10000
for j in tqdm(range(epochs), leave=False, desc='iter'):
    for i in range(steps):
        value, opt_state = update(
            i,
            opt_state,
            train_imgs[i : i + batch_size],
            train_points[i : i + batch_size],
        )

        plotlosses_model.update({'nose_pred_train':value}, current_step=iters)

        # get validation loss
        if i % val_ratio == 0:
            val_value, grads = jax.value_and_grad(loss_fn)(
                opt.params_fn(opt_state),
                val_imgs,
                val_points,
            )
            plotlosses_model.update({'nose_pred_val':val_value}, current_step=iters)
            plotlosses_model.send()
            
            if val_value <= min_val:
                save_params = opt.params_fn(opt_state)
                min_val = val_value
                
        iters += 1

In [None]:
for i in range(3):
    pred = model(save_params, val_imgs[i][None,...])
    print(pred, val_points[i])
    plt.imshow(val_imgs[i] + 0.5, cmap='gray')
    plt.scatter(val_points[i][0] * w, val_points[i][1] * h)
    plt.scatter(pred[i][0] * w, pred[i][1] * h)
    plt.show()

In [None]:
for i in range(3):
    pred = model(opt.params_fn(opt_state), train_imgs[i][None,...])
    print(pred, val_points[i])
    plt.imshow(train_imgs[i] + 0.5, cmap='gray')
    plt.scatter(train_points[i][0] * w, train_points[i][1] * h)
    plt.show()

In [None]:
plotlosses_model.send()