# 2D model without preprocessing

In [None]:
# 3rd-party imports
import numpy as np
import nibabel as nib
import tensorflow as tf
import tensorflow.keras as k
import matplotlib.pyplot as plt

In [None]:
import os
if not os.path.isdir("../notebooks") and not os.path.isdir("imed-project"):
  !git clone https://github.com/nicomem/imed-project.git
  %cd imed-project/notebooks

## Installing Requirements

In [None]:
# Go from notebooks/ up to the repository root, where requirements.txt is
# expected (assumes the previous cell left us in notebooks/).
%cd ..
!pip install -r requirements.txt
# gdown is needed by the next cell to download the dataset from Google Drive.
!pip install gdown

In [None]:
import gdown

# Fetch the dataset archive from Google Drive, unless an earlier run already
# left data.zip in the working directory.
already_downloaded = os.path.exists("data.zip")
if not already_downloaded:
  gdown.download(
      "https://drive.google.com/uc?id=1onHHWIhkhN5xYMit0rhhtVXlJrAlzCit",
      "data.zip",
      quiet=False,
  )

In [None]:
import zipfile

# Extract data.zip into the working directory unless a "data" directory is
# already present.
# Python's zipfile is used instead of the `unzip` shell tool for two reasons:
# - it needs no system package;
# - it never asks interactively whether to overwrite existing files. That
#   prompt can block a notebook cell, because the cell has no stdin.
if not os.path.isdir("data"):
  with zipfile.ZipFile("data.zip") as archive:
    archive.extractall()

In [None]:
# Return to notebooks/ so that the relative path '../data' used below points
# at the dataset extracted in the repository root, then list the directory.
%cd notebooks/
!ls

## Get the dataset & split it into train/validation sets

In [None]:
from utils.load_data import (
    CachedDataSequence,
    NibDataSequence,
    get_dataset,
)

# Build the training and validation splits from the extracted dataset.
# The path is relative to notebooks/ and points at the repository's data/.
train_nib, val_nib = get_dataset('../data', verbose=True)

In [None]:
# Show the keys of the training-split mapping returned by get_dataset.
train_nib.keys()

In [None]:
# Number of items stored under each key of the training split.
list(map(len, train_nib.values()))

In [None]:
# Number of items stored under each key of the validation split.
list(map(len, val_nib.values()))

## Load the training set & analyze it

In [None]:
# Wrap the training split in a NibDataSequence, then in a CachedDataSequence.
# The cached wrapper presumably keeps loaded volumes in memory; confirm in
# utils.load_data. The last line displays the sequence length.
train_seq_uncached = NibDataSequence(train_nib)
train_seq = CachedDataSequence(train_seq_uncached)
len(train_seq)

In [None]:
# Shapes of the first input volume and of its target. Inputs and targets
# differ in number of slices and in X/Y dimensions.
# NOTE(review): the model below uses only 'same'-padded convolutions, so its
# output has the input's height/width. The targets will likely need
# resampling/cropping to match before training; confirm.
train_seq.X[0].shape, train_seq.Y[0].shape

In [None]:
# Same shape check for the last volume of the training sequence.
train_seq.X[-1].shape, train_seq.Y[-1].shape

In [None]:
i_data = 20
i_slice = 25

# Show, side by side, both input channels of one slice and the target slice
# with the same index.
# NOTE(review): X and Y were observed above to have different slice counts,
# so the same i_slice may not select matching positions; confirm.
x_vol = train_seq.X[i_data]
y_vol = train_seq.Y[i_data]
panels = [x_vol[i_slice, ..., 0], x_vol[i_slice, ..., 1], y_vol[i_slice]]

plt.figure(figsize=(10, 10))
for pos, img in enumerate(panels, start=1):
    plt.subplot(1, 3, pos)
    plt.imshow(img)

## Prepare the model

In [None]:
# TODO: replace this placeholder with the actual model.

# Placeholder network: two 'same'-padded convolutions. They turn a 2-channel
# image of any height/width into a 1-channel sigmoid map of the same size.
inputs = k.Input(shape=(None, None, 2))
conv1 = k.layers.Conv2D(
    filters=2, kernel_size=5, padding='same', activation='relu'
)(inputs)
outputs = k.layers.Conv2D(
    filters=1, kernel_size=3, padding='same', activation='sigmoid'
)(conv1)

model = k.Model(inputs=inputs, outputs=outputs)
model.summary()

## Train the model

In [None]:
# Build the validation sequence the same way as the training one: raw
# NibDataSequence wrapped in a CachedDataSequence. Displays its length.
val_seq_uncached = NibDataSequence(val_nib)
val_seq = CachedDataSequence(val_seq_uncached)
len(val_seq)

In [None]:
# Binary cross-entropy pairs with the model's single sigmoid output channel.
model.compile(loss='binary_crossentropy', optimizer='adam')

In [None]:
# Train for a single epoch on the training sequence. The validation sequence is
# passed as validation data, and the returned History is kept in `history`.
history = model.fit(train_seq, epochs=1, validation_data=val_seq)

## Check the results

In [None]:
# TODO