# A demo for model predictions based on the transfer-learnt model

## Load essential packages

In [2]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# import packages and functions

import os
import sys
import glob
import numpy as np
import pandas as pd

import torch
from torch import optim, cuda
from torch import nn
from torch.functional import F
torch.backends.cudnn.benchmark = False
from torchsummary import summary

import warnings
warnings.filterwarnings('ignore', category=FutureWarning)

from dataset import BrainDataset
from load_save_checkpoint import load_checkpoint, save_checkpoint

## Configure device

In [None]:
# Batch size reported to the model summary further down.
batch_size = 32

# Use the GPU whenever CUDA is available.
train_on_gpu = cuda.is_available()
print(f'Train on gpu: {train_on_gpu}')

# Multi-GPU mode is only possible when CUDA is present and exposes more
# than one device; otherwise the flag stays False.
multi_gpu = False
if train_on_gpu:
    gpu_count = cuda.device_count()
    print(f'{gpu_count} gpus detected.')
    multi_gpu = gpu_count > 1

## Load data and model

### Load data

In [None]:
# A single input volume for the demo.
file_list = ['./119833_Motor_1_lh.nii.gz']

# One 7-element label vector per file, with a 1 at index 3 and 0 elsewhere
# (presumably a one-hot class label - confirm against BrainDataset).
label_list = [[int(position == 3) for position in range(7)]]

# is_train=False is passed through to BrainDataset; the file and label lists
# are matched positionally.
dataset = BrainDataset(
    file_list,
    label_list,
    is_train=False,
)

### Load model

In [None]:
# Checkpoint to restore (the "best loss" checkpoint, going by its file name).
checkpoint_path_bestloss = './checkpoint/3dconv-transfer_checkpoint_bestloss_v6.pth'

# load_checkpoint receives the device flags computed earlier and hands back
# both the model and its optimizer.
model, optimizer = load_checkpoint(
    checkpoint_path_bestloss,
    train_on_gpu,
    multi_gpu,
)

#### A summary of the model is shown below:

In [None]:
# Print a summary of the loaded model.
# TODO: the input_size below needs to be updated!!!
#
# In the multi-GPU case the object exposed as `model.module` is summarised
# rather than `model` itself (presumably `model` is then a parallel wrapper -
# confirm against load_checkpoint). The device string follows the GPU flag.
# The trailing semicolon keeps any return value of summary() out of the cell
# output.
summary(
    model.module if (train_on_gpu and multi_gpu) else model,
    input_size=(27, 75, 93, 81),
    batch_size=batch_size,
    device='cuda' if train_on_gpu else 'cpu',
);

## Generate a prediction from an input image

In [None]:
# Put the model in evaluation mode; assigning to `_` keeps the returned
# module from being echoed as cell output.
_ = model.eval()

# Run inference on the same device the model was loaded for.
inference_device = 'cuda' if train_on_gpu else 'cpu'

with torch.no_grad():  # no gradients are needed for a prediction
    inputs, label = dataset[0]

    # Add a leading batch axis of size 1 and move the sample to the device.
    batch = torch.FloatTensor(inputs).unsqueeze(0).to(device=inference_device)

    # Pass `dim` explicitly: calling softmax without it is deprecated, emits
    # a warning and chooses the axis implicitly from the tensor rank.
    # dim=1 is the axis the implicit rule selects for a 2-D output.
    # NOTE(review): assumes the model returns (batch, n_classes) - confirm.
    pred_prob = F.softmax(model(batch), dim=1)

pred_prob