# 1.0 Import Libraries

In [1]:
import os
import shutil
import json
import h5py
import tarfile
import zipfile
from fastai.basics import *
from tqdm.notebook import tqdm
from itertools import islice
from model import *

# 2.0 Submission Function

In [2]:
def submission(CITIES, INPUT_DIR, OUTPUT_DIR, MODEL_DIR, MASK_DIR):
    """Run inference for every city and package the predictions as ``submission.zip``.

    For each ``(city, model file)`` pair the function loads the checkpoint,
    builds the model input for every test date listed in
    ``{INPUT_DIR}/test_slots.json``, runs the model on the GPU and writes the
    sliced prediction to ``{OUTPUT_DIR}/{city}/{date}_test.h5``. Finally the
    whole ``OUTPUT_DIR`` tree is stored (uncompressed) in ``submission.zip``
    in the current working directory.

    Args:
        CITIES: Mapping of city name to the checkpoint file name inside
            ``MODEL_DIR``. The city name is also the sub-directory name under
            ``INPUT_DIR`` and ``OUTPUT_DIR``.
        INPUT_DIR: Directory holding ``test_slots.json`` and, per city,
            ``{city}_static_2019.h5`` and ``testing/{date}_test.h5``.
        OUTPUT_DIR: Directory that receives the per-city prediction files;
            created if missing.
        MODEL_DIR: Directory containing the checkpoint files named in ``CITIES``.
        MASK_DIR: Directory containing ``BERLIN_Mask_5.pt`` (only read when
            ``'BERLIN'`` is one of the cities).

    Side effects:
        Writes HDF5 files under ``OUTPUT_DIR`` and (re)creates
        ``submission.zip``. An existing ``submission.zip`` is overwritten so
        that entries from earlier runs cannot linger in the archive.
    """
    with open(f'{INPUT_DIR}/test_slots.json', 'r') as json_file:
        test_slots = json.load(json_file)
    # The JSON is a list of dicts; merge them into one {date: frame} mapping.
    test_slots = {k: v for each in test_slots for k, v in each.items()}
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    for city, MODEL_NAME in CITIES.items():

        print(city)
        input_path = Path(f'{INPUT_DIR}/{city}')
        working_path = Path(f'{OUTPUT_DIR}/{city}')
        os.makedirs(working_path, exist_ok=True)

        # City-specific model configuration.
        learn_incident = False
        active_node_feat = False
        if city == 'BERLIN':
            active_node_feat = True
            mask_name = f'{city}_Mask_5.pt'
            mask = torch.load(f'{MASK_DIR}/{mask_name}').float()
        if city == 'ISTANBUL':
            learn_incident = True

        model = NetO(active_node_feat=active_node_feat, learn_incident=learn_incident).cuda()
        state = torch.load(f'{MODEL_DIR}/{MODEL_NAME}')
        # A checkpoint is either a bare state dict or a {'model', 'opt'} pair.
        hasopt = set(state) == {'model', 'opt'}
        model_state = state['model'] if hasopt else state
        model.load_state_dict(model_state, strict=True)
        model.eval()

        # Load static features and move the last axis to the front
        # (presumably (H, W, C) -> (C, H, W); TODO confirm against the data spec).
        with h5py.File(f'{input_path}/{city}_static_2019.h5', 'r') as static_file:
            static_features = static_file.get('array')[()].astype(np.float32)
            static_features = torch.from_numpy(static_features).permute(2, 0, 1)

        # Loop through each test date
        for date, frame in tqdm(test_slots.items()):

            with h5py.File(f'{input_path}/testing/{date}_test.h5', 'r') as h5_file:
                x = h5_file.get('array')[()]

            # Note dimension reordering, from (Batch Size, Time, Height, Width, Channels)
            # to (Batch Size, Channels, Time, Height, Width)
            x = np.transpose(x, (0, 4, 1, 2, 3))
            x = torch.from_numpy(x).float()

            # Calculate output
            with torch.no_grad():
                N, C, D, H, W = x.shape
                if city == 'BERLIN':
                    x = x * mask
                # Fold channels and time into a single channel axis.
                x = x.reshape(N, C * D, H, W)
                # 1.0 wherever the summed input is positive, 0.0 elsewhere.
                active_nodes = ((x.sum(1) > 0) * 1.0).unsqueeze(1)
                s = torch.stack(N * [static_features], dim=0)
                # Constant plane per sample carrying the scaled frame index.
                # NOTE(review): 255/(288-12) presumably maps the slot index onto
                # the 0-255 range of the traffic channels -- confirm.
                t = torch.ones(N, 1, H, W)
                for j in range(N):
                    t[j] = t[j] * frame[j] * 255. / (288. - 12)
                if active_node_feat:
                    x = torch.cat([x, s, t, active_nodes], dim=1)
                else:
                    x = torch.cat([x, s, t], dim=1)

                # Run in two passes when there are more than three samples
                # (presumably to limit GPU memory use).
                if x.shape[0] > 3:
                    y1 = model(x[:3].cuda()).cpu()
                    y2 = model(x[3:].cuda()).cpu()
                    y = torch.cat([y1, y2])
                    del y1, y2
                else:
                    y = model(x.cuda()).cpu()

            # Dimension reordering: move axis 1 of the model output to the end.
            y = y.permute(0, 2, 3, 4, 1)
            # Assume output.shape == input.shape, hence slice out the bit required for submission
            y = y[:, [0, 1, 2, 5, 8, 11], :, :, :8]

            with h5py.File(f'{working_path}/{date}_test.h5', 'w') as h5_file:
                h5_file.create_dataset('array', data=y, compression="gzip", compression_opts=6)

            del x, y
            torch.cuda.empty_cache()

    # Create .zip file. The previous shell call passed the literal text
    # "OUTPUT_DIR" to zip instead of the directory held in the variable.
    # ZIP_STORED matches `zip -0` (the .h5 payloads are already gzip-compressed),
    # and entries keep the output directory name as their prefix, as `zip -r` would.
    archive_root = Path(OUTPUT_DIR)
    with zipfile.ZipFile('submission.zip', 'w', compression=zipfile.ZIP_STORED) as archive:
        for file_path in sorted(archive_root.rglob('*')):
            if file_path.is_file():
                archive.write(file_path, arcname=str(file_path.relative_to(archive_root.parent)))

# 3.0 Directory Setup

In [7]:
# City -> checkpoint file name (looked up inside MODEL_DIR).
# Cities are processed in the order listed here.
CITIES = dict(
    MOSCOW='moscow.pth',
    ISTANBUL='istanbul.pth',
    BERLIN='berlin.pth',
)

# Directories, relative to the notebook's working directory.
OUTPUT_DIR = 'submission'   # per-city prediction files are written here
MASK_DIR = 'masks'          # mask tensors
MODEL_DIR = 'models'        # model checkpoints
INPUT_DIR = '../storage'    # test_slots.json and per-city input data

# 4.0 Run Submission

In [9]:
# Generate predictions for every configured city and build submission.zip.
submission(
    CITIES=CITIES,
    INPUT_DIR=INPUT_DIR,
    OUTPUT_DIR=OUTPUT_DIR,
    MODEL_DIR=MODEL_DIR,
    MASK_DIR=MASK_DIR,
)