In [2]:
import argparse
import logging
import os

import numpy as np
import pandas as pd
import torch
import torch.optim as optim
from torch.autograd import Variable
from tqdm import tqdm
from train import train, train_and_evaluate
import matplotlib.pyplot as plt
from tifffile import imsave, TiffFile
from pathlib import Path

import utils
import model.net as net
from  model.data_loader import  fetch_dataloader
from evaluate import *

%load_ext autoreload
%autoreload 2
%matplotlib inline

## Initial training

Let's start with a simple baseline model:

- no data augmentation (even though augmentation is critical with such a small dataset);
- the default learning rate;
- the default optimizer (the original U-Net paper instead uses a high momentum with `batch_size = 1`);

and otherwise default settings (see `params` below).

### `params`

In [3]:
# Experiment directory: contains params.json, and later cells read
# history.csv and best.pth.tar from it. The path is relative to the
# notebook's current working directory.
model_dir = 'experiments/initial_model'
json_path = os.path.join(model_dir, 'params.json')
# Fail fast with an actionable message if the notebook was started from the
# wrong directory, instead of relying on whatever utils.Params raises.
assert os.path.isfile(json_path), (
    f"No params.json at {os.path.abspath(json_path)!r}; "
    "run the notebook from the project root or fix model_dir."
)
params = utils.Params(json_path)

In [4]:
# Display the hyperparameters loaded from params.json.
params.dict

{'learning_rate': 0.001,
 'batch_size': 1,
 'num_epochs': 10,
 'save_summary_steps': 100,
 'num_workers': 0}

### training

In [None]:
!python3 train.py --model_dir 'experiments/initial_model'

Loading the datasets...
- done.
Starting training for 10 epoch(s)
Epoch 1/10
100%|███████████████████████████████| 25/25 [04:33<00:00, 11.54s/it, loss=0.483]
- Train metrics: accuracy: 0.274 ; loss: 0.782
- Eval metrics : accuracy: 0.804 ; loss: 0.490
Checkpoint Directory exists! 
- Found new best accuracy
Epoch 2/10
 80%|████████████████████████▊      | 20/25 [03:19<00:48,  9.78s/it, loss=0.370]

### plots

Let's now plot the training and validation accuracy and loss recorded in `history.csv`.

In [None]:
# Metrics log for the run in model_dir (presumably written by train.py — confirm).
history_path = os.path.join(model_dir, 'history.csv')
# The training run above may have been interrupted before writing its history.
assert os.path.isfile(history_path), f"{history_path!r} not found; run training first."

In [None]:
df = pd.read_csv(history_path)
# The plots below need these columns. A clear message here beats a KeyError
# raised from inside pandas plotting.
missing_cols = {'train_acc', 'val_acc', 'train_loss', 'val_loss'} - set(df.columns)
assert not missing_cols, f"history.csv is missing columns: {sorted(missing_cols)}"
df.tail()

In [None]:
# Training vs. validation accuracy. The x axis is the row index of history.csv,
# assumed to be one row per epoch — TODO confirm against train.py.
ax = df[['train_acc', 'val_acc']].plot(title='Accuracy')
ax.set(xlabel='epoch', ylabel='accuracy');

In [None]:
# Training vs. validation loss. The x axis is the row index of history.csv,
# assumed to be one row per epoch — TODO confirm against train.py.
ax = df[['train_loss', 'val_loss']].plot(title='Loss')
ax.set(xlabel='epoch', ylabel='loss');

### prediction

#### get data and model

In [None]:
# Record GPU availability on params before building the loaders. The same
# flag also decides below whether the model is moved with .cuda().
params.cuda = torch.cuda.is_available()
dataloaders = fetch_dataloader(['train', 'test'], params)
train_dataloader, test_dataloader = dataloaders['train'], dataloaders['test']

In [None]:
# Take one batch from the *training* loader to sanity-check the model's output.
# NOTE(review): this is a training image, so the prediction below is not a
# held-out result. Consider the test loader if it also yields (image, target)
# pairs — confirm against fetch_dataloader.
image, target = next(iter(train_dataloader))

In [None]:
# Size of the test loader. For a torch DataLoader this is the number of batches,
# which equals the number of test images if it uses batch_size = 1
# (as in params) — confirm in fetch_dataloader.
len(test_dataloader)

In [None]:
model = net.Unet().cuda() if params.cuda else net.Unet()
checkpoint = os.path.join(model_dir, 'best.pth.tar')
# Fail with a clear message if training never saved a best checkpoint.
assert os.path.isfile(checkpoint), f"{checkpoint!r} not found; run training first."
utils.load_checkpoint(checkpoint, model)
# Put the module in evaluation mode before inference. This affects layers such
# as dropout/batch-norm, if Unet has any, and is harmless if predict_image also
# calls it. The trailing ';' hides the module repr that .eval() returns.
model.eval();

#### get prediction

In [None]:
# Run the model on the sample image. use_thresh=False presumably returns the
# raw (non-binarised) output — confirm in predict_image.
# no_grad skips building an autograd graph, which inference does not need.
# NOTE(review): when params.cuda is set the model is on the GPU while `image`
# came straight from the loader — confirm predict_image moves it to the device.
with torch.no_grad():
    prediction = predict_image(model, image, use_thresh=False)

In [None]:
# Inspect the output's dimensions before reshaping it for display.
np.shape(prediction)

In [None]:
# Summarise the prediction (type and value range) instead of dumping the whole array.
prediction.dtype, prediction.min(), prediction.max()

In [None]:
# Show the prediction as an image. Assumes it holds 512*512 values
# (a 512x512 input) — TODO confirm against the dataset's image size.
height, width = 512, 512
fig, ax = plt.subplots()
shown = ax.imshow(prediction.reshape(height, width), cmap='gray')
# Predicted with use_thresh=False, so show the value scale alongside the image.
fig.colorbar(shown, ax=ax)
ax.set(title='Prediction (use_thresh=False)')
ax.axis('off');