In [None]:
import os
import shutil
import json
import h5py
import tarfile
from fastai.basics import *
from tqdm.notebook import tqdm

from model import *

# Locations used throughout the notebook (all relative paths).
INPUT_DIR = "../input/traffic4cast2020"        # holds test_slots.json and one <CITY>.tar per city
WORKING_DIR = "../working"                     # archives are extracted here; predictions are written here
MODEL_DIR = "../input/traffic4cast2020models"  # directory containing the saved checkpoint
MODEL_NAME = "moscow_o070_m.pth"               # checkpoint file read from MODEL_DIR

# Cities to produce predictions for; the submission loop iterates over this list.
CITIES = ["MOSCOW"]

# Load test_slots.json

In [None]:
# test_slots.json is parsed as a list of dicts; read it first, then flatten.
with open(f'{INPUT_DIR}/test_slots.json', 'r') as slots_fh:
    slot_entries = json.load(slots_fh)

# Merge every entry into one flat mapping (later entries win on duplicate keys).
test_slots = {}
for entry in slot_entries:
    test_slots.update(entry)

# Load Model

In [None]:
# Build the network on the GPU, then restore the trained weights into it.
model = NetO().cuda()

# map_location='cpu' deserialises the checkpoint tensors on the host no matter
# which device they were saved from (e.g. a 'cuda:1' checkpoint on a single-GPU
# machine); load_state_dict below copies them into the model's CUDA parameters.
state = torch.load(f'{MODEL_DIR}/{MODEL_NAME}', map_location='cpu')

# The file is either a bare state dict or a dict with exactly the keys
# 'model' and 'opt'; in the latter case only the 'model' entry is used.
hasopt = set(state)=={'model', 'opt'}
model_state = state['model'] if hasopt else state

# strict=True: fail loudly if the checkpoint keys do not match the network.
model.load_state_dict(model_state, strict=True)
model.eval()

# Create Submission

In [None]:
# Clear working directory
!rm -r *

for city in CITIES:
    
    print(city)
    working_path = Path(f'{WORKING_DIR}/{city}')
    
    # Unzip data if it is not already done so
    if not working_path.exists():
        with tarfile.open(f'{INPUT_DIR}/{city}.tar') as tarred_file:
            files = [tarinfo for tarinfo in tarred_file.getmembers()
                     if tarinfo.name.startswith(f'{city}/testing/')
                     or tarinfo.name.startswith(f'{city}/{city}_static')] # Only unzipping the testing folder and static file
            tarred_file.extractall(members=files, path=WORKING_DIR)
    
    # load static features
    with h5py.File(f'{working_path}/{city}_static_2019.h5', 'r') as static_file:
        static_features = static_file.get('array')[()].astype(np.float32)
        static_features = torch.from_numpy(static_features).permute(2, 0, 1)
    static_features = static_features
    
    # Loop through each test date
    for date, frame in tqdm(test_slots.items()):
        
        with h5py.File(f'{working_path}/testing/{date}_test.h5', 'r') as h5_file:
            x = h5_file.get('array')[()]
            
        # Note dimension reordering, from (Batch Size, Time, Height, Width, Channels) to (Batch Size, Channels, Time, Height, Width)
        x = np.transpose(x, (0, 4, 1, 2, 3))
        x = torch.from_numpy(x).float()
            
        ##################################################################################################
        # Calculate output
        with torch.no_grad():
            
            N, C, D, H, W = x.shape
            x = x.reshape(N, C*D, H, W)
            s = torch.stack(N*[static_features], dim=0)
            t = torch.ones(N, 1, H, W)
            for j in range(N):
                t[j] = t[j] * frame[j] * 255. / (288. - 12)
            x = torch.cat([x, s, t], dim=1)
                
            if x.shape[0] > 3:
                y1 = model(x[:3].cuda()).cpu()
                y2 = model(x[3:].cuda()).cpu()
                y = torch.cat([y1, y2])
                del y1, y2
            else:
                y = model(x.cuda()).cpu()
            
            y = torch.round(y)
            y = torch.clamp(y, min=0, max=255)
        ##################################################################################################
        
        # Dimension reordering
        y = y.permute(0, 2, 3, 4, 1).byte()
        # Assume output.shape == input.shape, hence slice out the bit required for submission
        y = y[:,[0,1,2,5,8,11],:,:,:8]
        
        with h5py.File(f'{working_path}/{date}_test.h5', 'w') as h5_file:
            h5_file.create_dataset('array', data=y, compression="gzip", compression_opts=6)
            
        del x, y
        torch.cuda.empty_cache()
        
        # Delete the used files to save disk space...
        os.remove(f'{working_path}/testing/{date}_test.h5')
        
    # Delete data folder
    shutil.rmtree(f'{working_path}/testing')
    os.remove(f'{working_path}/{city}_static_2019.h5')

# Create .zip file
!zip -r0 submission.zip .