In [2]:
import numpy as np
import tensorflow as tf
import pandas as pd
import matplotlib.pyplot as plt
import comms
import ipywidgets as ipw
import keras 
import pickle
import vgg

from keras.models import Sequential 
from keras.layers import Dense, Dropout, Flatten, Input
from keras.layers import Conv2D, MaxPooling2D
from keras.optimizers import SGD
from keras.regularizers import l2
from keras.preprocessing.image import ImageDataGenerator

## Load the data 

In [3]:
# Location of the raw training JSON; kept as a named constant so it is easy to change.
# NOTE(review): comms.load_data presumably returns a pandas DataFrame (later cells
# call .sample on it) — confirm in comms.
TRAIN_JSON_PATH = "train.json/data/processed/train.json"
data_train = comms.load_data(TRAIN_JSON_PATH)

In [4]:
# Shuffle all rows of the training frame (frac=1 keeps every row).
# Keras' `validation_split` holds out the *last* fraction of the data before any
# shuffling, so this shuffle decides which rows become the validation set. Seed it
# so the train/validation split (and therefore the model comparison) is reproducible.
SHUFFLE_SEED = 42
data_train = data_train.sample(frac=1, random_state=SHUFFLE_SEED)

In [5]:
# Candidate architectures for the grid search. Each row is unpacked positionally
# into `vgg.create_vgg_simple(*params)` and holds four sub-lists. In every row the
# first two sub-lists have equal length, and so do the last two.
# NOTE(review): these look like (per-block conv setting, filters per block,
# dense-layer widths, dropout rate per dense layer), but the meaning is defined
# in `vgg` — confirm there.
params_list = [
    [[2, 2, 2],       [32, 64, 128],          [256],      [0.5]],
    [[3, 3, 3],       [32, 64, 128],          [256],      [0.5]],
    [[4, 4, 4],       [32, 64, 128],          [256],      [0.5]],
    [[3, 3],          [32, 64],               [256, 256], [0.5, 0.5]],
    [[3, 3],          [32, 64],               [128, 64],  [0.5, 0.5]],
    [[2, 2],          [32, 64],               [128],      [0.5]],
    [[2, 2],          [32, 32],               [64],       [0.5]],
    [[2, 2, 2, 2, 2], [16, 32, 64, 128, 256], [512],      [0.5]],
]

In [6]:
# Collects one Keras History object per trained candidate architecture.
histories = list()

## Train candidate models (grid search over `params_list`)

In [7]:
# Split the shuffled training frame into model inputs and iceberg labels using
# the project helpers in `comms`.
# NOTE(review): `ang_train` (presumably the incidence angle) is never used below —
# the models are fit on `img_train` only.
img_train, ang_train = comms.get_img_and_angle(data_train)
is_ice_train = comms.get_is_iceberg(data_train)

In [8]:
# Grid search: build and train one model per architecture in `params_list`.
# `histories` is reset here so re-running this cell does not append duplicate
# results from a previous run on top of the old ones.
histories = []
best_model, best_val_loss = None, float("inf")
for params in params_list:
    candidate = vgg.create_vgg_simple(*params)
    # validation_split holds out the last 10% of the (already shuffled) rows.
    history = candidate.fit(x=img_train,
                            y=is_ice_train,
                            validation_split=0.1,
                            epochs=20)
    histories.append(history)
    # Compare candidates on the final-epoch validation loss: those are the
    # weights the candidate actually ends up with.
    final_val_loss = history.history["val_loss"][-1]
    if final_val_loss < best_val_loss:
        best_model, best_val_loss = candidate, final_val_loss

# Downstream prediction cells use `model`; bind it to the best candidate rather
# than whichever architecture happened to be trained last.
model = best_model

Train on 1443 samples, validate on 161 samples
Epoch 1/20
 288/1443 [====>.........................] - ETA: 47s - loss: 1.1330

KeyboardInterrupt: 

In [None]:
# Persist the training curves for later comparison.
# - pickle writes bytes, so the file must be opened in binary mode ("wb"); text
#   mode "w" raises TypeError under Python 3.
# - `with` guarantees the handle is flushed and closed.
# - Keras History objects keep a reference to their model, which generally cannot
#   be pickled, so only the plain per-epoch metric dicts (`history.history`) are saved.
with open("histories.pkl", "wb") as histories_file:
    pickle.dump([h.history for h in histories], histories_file)

## Predict on the test data and write the submission file

In [None]:
# Location of the raw test JSON; kept as a named constant so it is easy to change.
TEST_JSON_PATH = "test.json/data/processed/test.json"
data_test = comms.load_data(TEST_JSON_PATH)

In [None]:
# Extract test-set model inputs with the same helper used for the training data.
# NOTE(review): `ang_test` is unused, matching the training side where the angle
# is not fed to the model.
img_test, ang_test = comms.get_img_and_angle(data_test)

In [None]:
# Score the test images with whatever model the training cell left bound to `model`.
# NOTE(review): hidden-state dependency — on a fresh kernel the training cell must
# have completed first, or `model` is undefined.
is_ice_test = model.predict(img_test)

In [None]:
# Row identifiers for the submission file, as a NumPy array.
id_test = np.array(data_test["id"])

In [None]:
# Assemble the submission table (one row per test id with its predicted score)
# and write it as CSV without the DataFrame index.
submission_columns = {
    "id": id_test.flatten(),
    "is_iceberg": is_ice_test.flatten(),
}
res = pd.DataFrame(submission_columns)
res.to_csv("result.csv", sep=",", index=False)