In [2]:
import l5kit
import numpy as np
import torch
from joblib import Parallel, delayed
import time
from l5kit.evaluation import write_pred_csv, read_pred_csv
from tqdm.notebook import tqdm

# Number of rows handed to ensembleTorch per call inside run_ensemble.
batch_size = 64
# Expected number of rows in every prediction CSV (used by run_ensemble).
N = 71_122

In [3]:
def _mode_responsibilities(x, ck, xij, cf):
    """Soft-assign every input trajectory to every ensemble mode, weighted by its confidence.

    Working layout of the broadcast tensors: (n, sc, bs, sc2, tl, 2).

    Args:
        x: current ensemble trajectories, shape (bs, sc2, tl, 2).
        ck: current ensemble confidences, shape (bs, sc2).
        xij: input trajectories, shape (n, sc, bs, 1, tl, 2).
        cf: model-weighted input confidences, shape (n, sc, bs, 1, 1, 1).

    Returns:
        Tensor of shape (n, sc, bs, sc2, 1, 1): the confidence mass each input
        trajectory contributes to each ensemble mode.
    """
    # Squared distance per time step (summed over the last axis): (n, sc, bs, sc2, tl, 1).
    dist = ((x[None, None] - xij) ** 2).sum(5, keepdim=True)
    log_ck = torch.log(ck[None, None, :, :, None, None])
    # Log-prior of the mode minus half the squared distance over the whole trajectory.
    logits = log_ck - 0.5 * dist.sum(4, keepdim=True)
    # Subtract the per-input maximum before exp (log-sum-exp trick) so exp cannot underflow to all zeros.
    values, _ = logits.max(3, keepdim=True)
    eij = torch.exp(logits - values)
    return cf * eij / eij.sum(3, keepdim=True)


def ensembleTorch(outputs, weights, confidences, num_modes=3, max_iter=20,
                  traj_tol=1e-3, conf_tol=1e-4):
    """Merge the multi-modal trajectory predictions of several models into ``num_modes`` modes.

    The first ``num_modes`` modes of model 0 are the starting point. Each round
    (at most ``max_iter`` rounds):
      * updates each ensemble trajectory, one mode at a time, to the
        responsibility-weighted mean of all input trajectories;
      * updates the ensemble confidences to the normalized total responsibility
        each mode receives.
    Iteration stops when no trajectory changed by more than ``traj_tol`` and the
    confidences changed by at most ``conf_tol``. Both checks use the maximum
    absolute change over the whole batch, so the result can depend on how rows
    are grouped into batches.

    Args:
        outputs: trajectories, shape (n, bs, sc, tl, 2). Here n is the number of
            models, bs the number of rows, sc the modes per model and tl the time steps.
        weights: 1-D tensor of n per-model weights, on the same device as ``outputs``.
        confidences: per-mode confidences, shape (n, bs, sc).
        num_modes: number of ensemble modes to produce; requires sc >= num_modes.
        max_iter: maximum number of refinement rounds.
        traj_tol: change threshold for accepting a trajectory update.
        conf_tol: change threshold for accepting a confidence update.

    Returns:
        (x, ck): ensemble trajectories of shape (bs, num_modes, tl, 2) and
        ensemble confidences of shape (bs, num_modes).

    Raises:
        AssertionError: on shape mismatches, or if a NaN appears during iteration.
    """
    sc2 = num_modes
    n, bs, sc, tl, _ = outputs.shape
    assert (n, bs, sc) == confidences.shape
    assert n == len(weights)
    assert sc >= sc2, "each model must provide at least num_modes modes"

    # Read-only views in the (n, sc, bs, sc2, tl, 2) broadcast layout.
    xij = outputs.transpose(1, 2)[:, :, :, None, :, :]
    cf = (weights[:, None, None] * confidences.transpose(1, 2))[:, :, :, None, None, None]

    # Initial guess: model 0's leading modes. x is cloned because it is updated in place below.
    x = outputs[0, :, :sc2].clone()
    ck = confidences[0, :, :sc2].clone()
    if sc != sc2:
        # Dropped modes leave the prior unnormalized; renormalize over the kept ones.
        ck = ck / ck.sum(1, keepdim=True)

    for m in range(max_iter):
        any_update = False

        # Update one trajectory at a time (recomputing responsibilities in between) for convergence.
        for s in range(sc2):
            mij = _mode_responsibilities(x, ck, xij, cf)
            x_new = (mij * xij).sum((0, 1)) / torch.clamp(mij.sum((0, 1)), 1e-9)
            change = (x[:, s] - x_new[:, s]).abs().max()
            assert not torch.isnan(change)
            if change > traj_tol:
                x[:, s] = x_new[:, s]
                any_update = True

        # Confidence update: normalized responsibility mass per ensemble mode.
        mij = _mode_responsibilities(x, ck, xij, cf)
        ck_new = mij.sum((0, 1)).squeeze(dim=3).squeeze(dim=2)
        ck_new = ck_new / ck_new.sum(1, keepdim=True)
        change = (ck - ck_new).abs().max()
        # Defensive check: normalized confidences should never differ by more than 1.
        assert not ((m > 10) and (change > 1))
        assert not torch.isnan(change)
        if change > conf_tol:
            ck = ck_new
            any_update = True

        if not any_update:
            break

    return x, ck

In [4]:
@torch.no_grad()
def run_ensemble(names, weights, output_path="submission.csv", device="cuda"):
    """Ensemble several prediction CSVs row by row and write the merged CSV.

    The files are read in lockstep with ``read_pred_csv``. Row i of every file
    must describe the same timestamp/track_id. Rows are buffered into batches of
    the module-level ``batch_size`` and merged with ``ensembleTorch``. The
    module-level ``N`` only sizes the progress bar; the final partial batch is
    always flushed.

    Args:
        names: paths of the prediction CSVs to combine.
        weights: 1-D tensor with one relative weight per file. It is normalized
            to sum to 1 internally; the caller's tensor is not modified.
        output_path: destination file passed to ``write_pred_csv``.
        device: torch device on which ``ensembleTorch`` runs.

    Raises:
        ValueError: if the files have different numbers of rows, are not
            row-aligned, or contain no rows at all.
    """
    from itertools import zip_longest

    st = time.time()
    # Out-of-place normalization so re-running the cell never sees pre-mutated weights.
    weights = (weights / weights.sum()).to(device)

    pred_coords_list = []
    confidences_list = []
    timestamps_list = []
    track_id_list = []

    def batch_processing(coords, confs):
        """Ensemble one buffered batch and append the numpy results."""
        # Stack rows on dim 1 -> coords (n, bs, sc, tl, 2), confs (n, bs, sc).
        coords = torch.stack(coords, dim=1).to(device)
        confs = torch.stack(confs, dim=1).to(device)
        x, ck = ensembleTorch(coords, weights, confs)
        pred_coords_list.append(x.cpu().numpy())
        confidences_list.append(ck.cpu().numpy())

    readers = [read_pred_csv(name) for name in names]
    coords, confs = [], []
    # zip_longest (not zip) so files of different length are detected instead of silently truncated.
    for i, rows in tqdm(enumerate(zip_longest(*readers)), total=N):
        if any(row is None for row in rows):
            raise ValueError(f"prediction files have different numbers of rows (first mismatch at row {i})")
        timestamp, track_id = rows[0]["timestamp"], rows[0]["track_id"]
        if any(row["timestamp"] != timestamp or row["track_id"] != track_id for row in rows[1:]):
            raise ValueError(f"prediction files are not row-aligned at row {i}")

        coords.append(torch.stack([torch.tensor(row['coords']) for row in rows], dim=0))
        confs.append(torch.stack([torch.tensor(row['conf']) for row in rows]))
        timestamps_list.append(timestamp)
        track_id_list.append(track_id)

        if len(coords) == batch_size:
            batch_processing(coords, confs)
            coords, confs = [], []

    # Flush the trailing partial batch, whatever the actual row count is.
    if coords:
        batch_processing(coords, confs)

    if not pred_coords_list:
        raise ValueError("no prediction rows to ensemble")

    coords = np.concatenate(pred_coords_list)
    confs = np.concatenate(confidences_list)
    timestamps = np.array(timestamps_list)
    track_ids = np.array(track_id_list)

    print('running time:', time.time() - st)

    st = time.time()
    write_pred_csv(
        output_path,
        timestamps=timestamps,
        track_ids=track_ids,
        coords=coords,
        confs=confs)
    print('write csv time:', time.time() - st)

In [5]:
# Ensemble configuration A: three submissions with relative weights
# (run_ensemble rescales the weights to sum to 1).
# NOTE(review): the following cell reassigns `names` and `weights`, so this
# configuration is only used if that cell is skipped.
names = [
    'submission12285.csv',
    'submission12301.csv',
    'submission12976.csv',
]
weights = torch.tensor([0.4, 0.3, 0.3])

In [6]:
# Ensemble configuration B: four submissions. The weights sum to 1.15;
# run_ensemble rescales them to sum to 1, so only their ratios matter.
names = [
    'submission_B3.csv',
    'submission_B5.csv',
    'submission_B6_1.csv',
    'submission_B6_2.csv',
]
weights = torch.tensor([0.4, 0.25, 0.25, 0.25])

In [7]:
# Merge the configured submissions and write the ensembled CSV.
run_ensemble(names=names, weights=weights)

HBox(children=(FloatProgress(value=0.0, max=71122.0), HTML(value='')))


running time: 221.61924505233765
write csv time: 35.55594205856323
