In [1]:
import matplotlib.pyplot as plt
%matplotlib inline
import os
import logging as log
from time import strftime
from copy import deepcopy
from torch import nn, optim
import torch.nn.functional as F
from utils.data_processing import *
from logger.logger import setup_logging
from utils.configs import BaseConf
from utils.metrics import best_threshold
from utils.utils import write_json, Timer
from dataloaders.grid_loader import GridDataLoaders
from datasets.grid_dataset import GridDataGroup
from utils.metrics import PRCurvePlotter, ROCCurvePlotter, LossPlotter, best_threshold, get_y_pred, \
                                get_y_pred_by_thresholds, best_thresholds
from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score
from models.model_result import ModelResult, ModelMetrics
from trainers.generic_trainer import train_model
from utils.plots import im
from utils.utils import pshape, get_data_sub_paths, by_ref
from models.model_result import save_metrics,save_results, compare_models, compare_all_models, get_metrics_table, \
                            get_models_metrics, get_models_results
from models.st_resnet_models import STResNet, STResNetExtra
from models.st_resnet_models import train_epoch_for_st_res_net, train_epoch_for_st_res_net_extra
from models.st_resnet_models import evaluate_st_res_net, evaluate_st_res_net_extra
import pandas as pd
from pprint import pprint
from torch.optim import lr_scheduler
# Remove pandas' column display limit (None = unlimited) so wide tables
# are displayed in full instead of being truncated.
pd.set_option('display.max_columns', None)

In [2]:
# Collect the data-set sub-paths reported by get_data_sub_paths() and show
# them in sorted order so one can be picked in the next cell.
# NOTE(review): presumably these are the folder names under ./data/processed/
# (that is how the chosen name is used further down) - confirm.
data_sub_paths = get_data_sub_paths()
pprint(sorted(data_sub_paths))

['T1H-X3400M-Y3520M_2014-01-01_2016-01-01_#7cd',
 'T24H-X255M-Y220M_2012-01-01_2019-01-01_#c97',
 'T24H-X425M-Y440M_2012-01-01_2019-01-01_#827',
 'T24H-X850M-Y880M_2012-01-01_2019-01-01_#826']


In [3]:
# Pick the data set to explore via its short reference ("c97").
# by_ref is indexed below, so it is assumed to return a sequence of matching
# sub-path names; only the first match is used.
matching_sub_paths = by_ref("c97")
data_sub_path = matching_sub_paths[0]
print(f"using: {data_sub_path}")

using: T24H-X255M-Y220M_2012-01-01_2019-01-01_#c97


In [4]:
# ---------------------------------------------------------------------------
# Run configuration: paths, logging, seeding, device selection, data loading.
# NOTE(review): `np` and `torch` are not imported explicitly in this notebook;
# presumably they are provided by `from utils.data_processing import *`.
# ---------------------------------------------------------------------------
conf = BaseConf()

# Name of the sub-folder (under <data_path>/models/) that holds this run's
# artefacts; the folder is created below if it does not exist yet.
conf.model_name = "explore-data"

conf.data_path = f"./data/processed/{data_sub_path}/"

# The processed data must already be on disk; fail early with a specific
# error instead of silently creating an empty directory. isdir() also rejects
# a regular file sitting at that path.
if not os.path.isdir(conf.data_path):
    raise FileNotFoundError(f"Directory ({conf.data_path}) needs to exist.")

conf.model_path = f"{conf.data_path}models/{conf.model_name}/"
# Only the model directory has to be created - data_path was verified above.
os.makedirs(conf.model_path, exist_ok=True)

# logging config is set globally thus we only need to call this in this file
# imported function logs will follow the configuration
setup_logging(save_dir=conf.model_path, log_config='./logger/standard_logger_config.json', default_level=log.INFO)
log.info("=====================================BEGIN=====================================")

# Snapshot of the configuration taken at this point.
# NOTE(review): values assigned to `conf` further down (device, batch_size)
# are not part of this snapshot; only "device" is added to `info` by hand.
info = deepcopy(conf.__dict__)
info["start_time"] = strftime("%Y-%m-%dT%H:%M:%S")

# DATA LOADER SETUP
# Seed numpy and torch from the same configured seed.
np.random.seed(conf.seed)
use_cuda = False  # CPU is forced here; switch to torch.cuda.is_available() to allow GPU use
torch.manual_seed(conf.seed)
if use_cuda:
    # Seed the CUDA generator too and ask cuDNN for deterministic behaviour.
    torch.cuda.manual_seed(conf.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0" if use_cuda else "cpu")
log.info(f"Device: {device}")
info["device"] = device.type  # store the device type string rather than the torch.device object
conf.device = device

conf.batch_size = 128

# CRIME DATA
# Load the grid data found under data_path, then wrap it in data loaders;
# both receive the same `conf`.
data_group = GridDataGroup(data_path=conf.data_path,
                           conf=conf)

loaders = GridDataLoaders(data_group=data_group,
                          conf=conf)

2020-10-23T12:36:16 | root | INFO | Device: cpu
2020-10-23T12:36:16 | root | INFO | Initialising Grid Data Group
2020-10-23T12:36:16 | root | INFO | 	t_range: (2558,) 2012-01-01 00:00:00 -> 2019-01-01 00:00:00
2020-10-23T12:36:16 | root | INFO | 	target_len:	2515	(100.000%)
2020-10-23T12:36:16 | root | INFO | 	trn_val_size:	2155	(85.686%)
2020-10-23T12:36:16 | root | INFO | 	trn_size:	1617	(64.294%)
2020-10-23T12:36:16 | root | INFO | 	val_size:	538	(21.392%)
2020-10-23T12:36:16 | root | INFO | 	tst_size:	360 	(14.314%)


In [5]:
from utils.interactive import InteractiveHeatmaps

# Interactive heat-map browser over the data group's time range.
# A single panel is requested (col_wrap=1), fed with index 0 along axis 1 of
# data_group.crimes and passed under the keyword "Counts".
# NOTE(review): the keyword name presumably doubles as the panel label - confirm.
crime_heatmaps = InteractiveHeatmaps(
    date_range=data_group.t_range,
    col_wrap=1,
    Counts=data_group.crimes[:, 0],
)

# Leaving `.app` as the cell's final expression makes Jupyter render the widget.
crime_heatmaps.app

VBox(children=(Label(value='Date: Mon Jan  2 00:00:00 2012'), HBox(children=(Play(value=0, description='Press …

In [6]:
# Pass the data group's `crimes` array through its own to_counts() method.
# NOTE(review): presumably this turns the sparse crime representation into
# count grids - confirm against GridDataGroup.to_counts.
grids = data_group.to_counts(sparse_data=data_group.crimes)

In [7]:
from utils.interactive import interactive_grid_visualiser

# Only index 0 along axis 1 of `grids` is visualised.
# (An np.flipud(grids[:, 0]) variant was tried here previously and is not used.)
first_slice_grids = grids[:, 0]

# Launch the interactive grid explorer with mutual information enabled.
# NOTE(review): max_offset=365 presumably caps the offset at 365 time steps
# (one year for this T24H data set) - confirm.
interactive_grid_visualiser(
    grids=first_slice_grids,
    t_range=data_group.t_range,
    mutual_info=True,
    max_offset=365,
)


Box(children=(FigureWidget({
    'data': [{'type': 'heatmap',
              'uid': 'cc3523be-6d41-4bef-892d-1e…