In [1]:
# Development convenience: with autoreload mode 2, imported modules are
# re-imported before each cell executes, so edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch 
from src import AllModule

# Restore the trained weights for inference.
#
# map_location='cpu' makes the load independent of the device the checkpoint
# was saved from: without it, tensors are restored onto their original device,
# which fails on a machine without that CUDA device and otherwise leaves a
# temporary copy of the weights on the GPU. The model itself is moved to the
# GPU later, right before inference.
#
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted
# source.
state_dict = torch.load(
    'checkpoints/all_r34-val_error=0.70467-epoch=3.ckpt',
    map_location='cpu',
)['state_dict']

# NOTE(review): the checkpoint file name says "r34" while the backbone here is
# 'resnet18'. load_state_dict is strict by default, so a successful load means
# the architecture matches the stored weights -- the file name looks
# misleading; confirm which backbone was actually trained.
# NOTE(review): pretrained=True presumably fetches pretrained backbone weights
# that are overwritten by the checkpoint right below -- confirm whether it can
# be switched off here.
model = AllModule(dict(backbone='resnet18', pretrained=True, mlp_layers=[256, 512], mlp_dropout=0.))
model.load_state_dict(state_dict)

<All keys matched successfully>

In [3]:
from tqdm import tqdm
from src import AllDataModule

# Test-time augmentation (TTA): run the model over the test set once per
# transform and stack the per-round predictions into `all_probas`.
#
# Produces:
#   all_probas   -- tensor of shape (len(tta_trans), N, D); assumes
#                   model.predict returns one row of D scores per sample
#                   -- TODO confirm against AllModule.predict
#   observations -- list of observation ids, in the row order of all_probas

model.cuda()
# Inference mode: freezes batch-norm statistics and disables dropout.
# predict() may already do this internally; calling it here is harmless.
model.eval()

# One entry per TTA round; None means "no augmentation".
# NOTE(review): RandomRotate90 with p=1 is always applied, but the rotation it
# picks is presumably random per sample, so that round is not reproducible
# unless the RNGs are seeded -- confirm.
tta_trans = [
    None,
    {'HorizontalFlip': {'p': 1}},
    {'VerticalFlip': {'p': 1}},
    {'Transpose': {'p': 1}},
    {'RandomRotate90': {'p': 1}},
]

dm = AllDataModule(batch_size=512, num_workers=10, pin_memory=True)
dm.setup()

round_probas = []
observations = None
for trans in tta_trans:
    # Rebuild the datasets so the test split uses this round's transform.
    dm.test_trans = trans
    dm.generate_datasets()

    batch_probas, round_observations = [], []
    # No gradients are needed for inference; skipping them saves memory.
    with torch.no_grad():
        for batch in tqdm(dm.test_dataloader()):
            # Collect per-batch outputs and concatenate once per round:
            # concatenating onto a growing tensor every batch recopies
            # everything gathered so far.
            batch_probas.append(model.predict(batch).cpu())
            round_observations += batch['observation_id'].cpu().numpy().tolist()

    # Averaging over rounds is only valid if every round visits the samples
    # in the same order, so fail loudly if the order ever differs.
    if observations is None:
        observations = round_observations
    elif round_observations != observations:
        raise RuntimeError(
            'Test samples were returned in a different order between TTA rounds; '
            'predictions cannot be averaged row by row.'
        )

    round_probas.append(torch.cat(batch_probas, dim=0))

all_probas = torch.stack(round_probas, dim=0)  # (len(tta_trans), N, D)

train: 1587395
val: 40080
test: 36421


100%|██████████| 72/72 [00:55<00:00,  1.30it/s]


Compose([
  HorizontalFlip(always_apply=False, p=1),
], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={'nir': 'image', 'alt': 'image', 'lc': 'image'})


100%|██████████| 72/72 [00:53<00:00,  1.35it/s]


Compose([
  VerticalFlip(always_apply=False, p=1),
], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={'nir': 'image', 'alt': 'image', 'lc': 'image'})


100%|██████████| 72/72 [00:54<00:00,  1.32it/s]


Compose([
  Transpose(always_apply=False, p=1),
], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={'nir': 'image', 'alt': 'image', 'lc': 'image'})


100%|██████████| 72/72 [00:54<00:00,  1.32it/s]


Compose([
  RandomRotate90(always_apply=False, p=1),
], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={'nir': 'image', 'alt': 'image', 'lc': 'image'})


100%|██████████| 72/72 [00:55<00:00,  1.29it/s]


In [4]:
# Average the scores over the TTA rounds (first axis of all_probas).
mean_probas = all_probas.mean(0)

# Keep the 30 best-scoring class indices per row, highest score first.
values, ixs = mean_probas.topk(k=30)

# One space-separated string of class indices per row.
final_labels = [' '.join(map(str, row)) for row in ixs.tolist()]

In [5]:
import pandas as pd 

# Pair every observation id with its space-separated predictions and write the
# result to disk without the DataFrame index.
submission_columns = {'Id': observations, 'Predicted': final_labels}
submission = pd.DataFrame(submission_columns)
submission.to_csv('submission.csv', index=False)

# Spot-check a few random rows.
submission.sample(5)

Unnamed: 0,Id,Predicted
12569,10058120,858 227 253 1041 421 472 160 305 238 429 1030 ...
5327,10527804,376 938 202 1174 82 273 447 802 879 864 452 44...
30361,21417062,5539 6059 5295 5441 7415 5694 6393 9423 9150 9...
17299,20000341,3253 5109 2093 5124 5418 5004 4894 13818 5068 ...
5273,10515131,1041 49 263 380 660 29 256 103 287 166 397 748...


In [6]:
# Sanity-check the generated submission against the provided sample file.
# Equal length alone would not catch wrong or duplicated ids, so the column
# layout and the set of ids are compared as well.
sample_submission = pd.read_csv('data/sample_submission.csv')
assert len(sample_submission) == len(submission)
assert list(submission.columns) == list(sample_submission.columns), (
    f'unexpected columns: {list(submission.columns)}'
)
assert set(submission['Id']) == set(sample_submission['Id']), (
    'submission ids do not match the ids in the sample submission'
)
sample_submission.sample(10)

Unnamed: 0,Id,Predicted
22620,20576555,0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18...
19925,20289050,0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18...
33993,21804222,0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18...
2618,10124788,0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18...
30673,21449931,0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18...
8220,10389641,0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18...
8198,10388766,0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18...
14318,10684002,0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18...
30843,21470582,0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18...
5630,10268444,0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18...
