## COSC522: Final Project - Catifier
### Cameron Adkins, Purnachandra Anirudh Gajjala, Gabriel Abeyie

In [None]:
# Numpy.
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

# Need plots.
import matplotlib.pyplot as plt

# Pandas.
import pandas as pd

# Machine learning toolkit.
from sklearn.model_selection import KFold, cross_val_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import *
from sklearn.preprocessing import Normalizer
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report, accuracy_score

# SciPy for FFTs and the like.
import scipy as sc
import scipy.io.wavfile as wavfile
from scipy import signal
from scipy.fftpack import fft, fftfreq
from scipy import stats

# Seaborn for plots.
import seaborn as sns

# IPython for basic visual output types.
import IPython

# Imbalanced learn for rebalancing.
from imblearn.over_sampling import SMOTE

# Standard Python libs.
import os
import glob
import csv
import xml.etree.ElementTree as et
from dataclasses import dataclass

# Pillow.
import PIL as pil

# Tensorflow
#import tensorflow as tf
#from tensorflow.keras.preprocessing.image import ImageDataGenerator
#from tensorflow.keras.datasets import fashion_mnist
#from tensorflow.keras.layers import Dense
#from tensorflow.keras.optimizers import Adam
#from tensorflow.keras.layers import Conv2D, Flatten, Dense, AveragePooling2D, GlobalAveragePooling2D,Dropout
#from tensorflow.keras.applications.resnet import ResNet50
#from tensorflow.keras.models import Sequential
#from tensorflow.keras.preprocessing.image import ImageDataGenerator

# For images
#from skimage.color import rgb2gray
#import cv2
#from scipy import ndimage

In [None]:
# Utility functions.

def image_load(filename):
    """Load an image from disk and return an in-memory copy of it.

    The file is opened only for the duration of the call; the caller receives
    a copy that stays usable after the file has been closed.

    Args:
        filename: Path of the image file to open with Pillow.

    Returns:
        A copy of the opened Pillow image.
    """
    # The context manager closes the file even when copy() raises, which the
    # explicit open/copy/close sequence would not.
    with pil.Image.open(filename) as loader:
        return loader.copy()

def trimap_to_mask(trimap, include_border = True):
    """Convert a trimap image into a binary 8-bit mask image.

    Pixels whose trimap value is 1 become 255 in the mask. Pixels whose value
    is 3 also become 255 when ``include_border`` is true. Every other pixel
    becomes 0.

    Args:
        trimap: Pillow image holding one class id per pixel. A single-band
            image is expected.
        include_border: Whether pixels with value 3 are included in the mask.

    Returns:
        A Pillow image built from a ``uint8`` array of shape
        ``(trimap.height, trimap.width)`` containing only 0 and 255.
    """
    pixels = np.asarray(trimap)

    if pixels.ndim != 2:
        # Multi-band pixels are tuples, which never equal a scalar class id,
        # so nothing is selected and the mask is empty.
        empty = np.zeros((trimap.height, trimap.width), dtype=np.uint8)
        return pil.Image.fromarray(empty)

    # Whole-array comparison instead of a Python loop over every pixel.
    selected = pixels == 1
    if include_border:
        selected = selected | (pixels == 3)

    mask_data = np.where(selected, 255, 0).astype(np.uint8)

    return pil.Image.fromarray(mask_data)

In [None]:
# Class definitions.

@dataclass
class Point:
    """A pair of integer coordinates."""

    x: int
    y: int

@dataclass
class BoundingBox:
    """The four corner points of a rectangular region.

    CatBreedSample fills the corners as ll=(xmin, ymin), lr=(xmax, ymin),
    ul=(xmin, ymax) and ur=(xmax, ymax).
    """

    ll: Point
    lr: Point
    ul: Point
    ur: Point

class CatBreedSample:
    """A labelled image together with optional mask and bounding-box views.

    Attributes:
        label: Label passed in by the caller.
        image_file: Path of the source image.
        mask_file: Path of the trimap file, or None.
        bb_file: Path of the bounding-box XML file, or None.
        image: The loaded source image.
        mask: Binary mask derived from the trimap, or None without a mask file.
        masked_image: The image composited over a black background through
            ``mask``, or None without a mask file.
        bb: BoundingBox read from the XML file, or None without one.
        bounded_image: The image cropped to the bounding box, or None.
    """

    def __init__(self, label, image_file, mask_file = None, bb_file = None):
        """Load the image and build whichever derived views have input files.

        Args:
            label: Label to attach to the sample.
            image_file: Path of the image to load.
            mask_file: Optional path of a trimap image.
            bb_file: Optional path of a bounding-box XML file.

        Raises:
            ValueError: If ``bb_file`` lacks a bounding-box coordinate or
                holds a non-integer value.
        """
        self.label = label

        self.image_file = image_file
        self.mask_file = mask_file
        self.bb_file = bb_file

        self.image = image_load(self.image_file)

        # Composite if a mask is available.
        if self.mask_file:
            self.mask = trimap_to_mask(image_load(self.mask_file))

            # Image.composite needs both images in the same mode. The
            # background is RGB, so non-RGB sources are converted first.
            if self.image.mode == "RGB":
                foreground = self.image
            else:
                foreground = self.image.convert("RGB")

            background = pil.Image.new("RGB", self.mask.size, 0)
            self.masked_image = pil.Image.composite(foreground, background, self.mask)
        else:
            self.mask = None
            self.masked_image = None

        # Get a bounding box.
        if self.bb_file:
            xmin, ymin, xmax, ymax = self._read_bndbox(self.bb_file)

            self.bb = BoundingBox(
                ll=Point(xmin, ymin),
                lr=Point(xmax, ymin),
                ul=Point(xmin, ymax),
                ur=Point(xmax, ymax),
            )

            self.bounded_image = self.image.crop((xmin, ymin, xmax, ymax))
        else:
            self.bb = None
            self.bounded_image = None

    @staticmethod
    def _read_bndbox(bb_file):
        """Return ``(xmin, ymin, xmax, ymax)`` from the first object's bndbox.

        Raises:
            ValueError: If a coordinate element is missing or not an integer.
        """
        root = et.parse(bb_file).getroot()

        coords = []
        for tag in ("xmin", "ymin", "xmax", "ymax"):
            text = root.findtext("./object/bndbox/" + tag)
            if text is None:
                raise ValueError(
                    "Bounding box file '" + str(bb_file) + "' has no '" + tag + "' element"
                )
            coords.append(int(text))

        return tuple(coords)

    def display(self):
        """Print the file name and label, then show every available image."""
        # Imported here so the method does not depend on the `display` name
        # that notebooks inject into the global namespace.
        from IPython.display import display as show

        print("Image:", self.image_file)
        print("Label:", self.label)

        show(self.image)

        if self.mask is not None:
            show(self.masked_image)

        if self.bb is not None:
            show(self.bounded_image)

In [None]:
# Load everything.
def load_samples(samples_dir):
    """Load every ``.jpg`` sample found in the class folders of a directory.

    Each sub-directory of ``samples_dir`` is treated as one class, and its
    name becomes the label. For an image ``<name>.jpg`` the optional companion
    files ``<name>_mask.png`` and ``<name>_bb.xml`` are used when they exist.
    Progress is printed as files are read.

    Args:
        samples_dir: Directory containing one sub-directory per class.

    Returns:
        A list of CatBreedSample objects, ordered by class directory and then
        by image path.
    """
    samples = []

    # Escape the root so glob characters in the path are taken literally, and
    # sort so sample indices are the same on every run.
    class_dirs = sorted(glob.glob(os.path.join(glob.escape(samples_dir), "*")))

    for class_dir in class_dirs:
        # Plain files next to the class folders are not classes.
        if not os.path.isdir(class_dir):
            continue

        label = os.path.basename(class_dir)

        print("Reading files for label '" + label + "'")

        image_pattern = os.path.join(glob.escape(class_dir), "*.jpg")

        for sample_image in sorted(glob.glob(image_pattern)):
            basename = os.path.splitext(sample_image)[0]
            mask_file = basename + "_mask.png"
            bb_file = basename + "_bb.xml"

            if not os.path.exists(mask_file):
                mask_file = None

            if not os.path.exists(bb_file):
                bb_file = None

            print(
                len(samples),
                ":",
                os.path.basename(sample_image),
                "\t(mask:",
                (mask_file is not None),
                "| bb:",
                (bb_file is not None),
                ")"
            )

            samples.append(CatBreedSample(label, sample_image, mask_file, bb_file))

    return samples

# The training set is read from the "training_data" folder of the current
# working directory.
PWD = os.getcwd()
TRAINING_DATA = f"{PWD}/training_data"
samples = load_samples(TRAINING_DATA)

In [None]:
# Sanity check: print and show the first loaded sample.
samples[0].display()

# Disabled manual check that builds one sample from explicit file paths.
# NOTE(review): the paths below are absolute and specific to one machine.
#test_sample = CatBreedSample(
#    "Sphynx", 
#    "/home/cva/files/devel/catifier/training_data/Sphynx/Sphynx_192.jpg",
#    "/home/cva/files/devel/catifier/training_data/Sphynx/Sphynx_192_mask.png",
#    "/home/cva/files/devel/catifier/training_data/Sphynx/Sphynx_192_bb.xml"
#)
#
#test_sample.display();