In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell executes, so edits to imported project code are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch 
from src import AllModule

# Read the trained weights from the checkpoint file; they are stored under its
# 'state_dict' key.
# map_location='cpu' restores every tensor on the CPU regardless of the device
# it was saved from, so loading neither depends on the original GPU layout nor
# allocates a second copy of the weights on the GPU.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
state_dict = torch.load(
    'checkpoints/all_r34-val_error=0.70467-epoch=3.ckpt',
    map_location='cpu',
)['state_dict']

# Rebuild the module with the hyper-parameters used for this checkpoint.
# NOTE(review): pretrained=True presumably fetches backbone weights that are
# immediately overwritten by load_state_dict below -- confirm whether
# pretrained=False would be equivalent here.
model = AllModule(dict(backbone='resnet34', pretrained=True, mlp_layers=[256, 512], mlp_dropout=0.))

# Kept as the last expression so the cell displays the key-matching report.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [3]:
from tqdm import tqdm
from src import AllDataModule

# Inference only: move the model to the GPU and switch it to eval mode so that
# normalisation/dropout layers behave deterministically.
# NOTE(review): model.predict may already do this internally; if so, the
# eval()/no_grad() here are harmless no-ops.
model.cuda()
model.eval()

# Test-time augmentation (TTA) views. None is the un-augmented pass; the other
# entries are transform specs handed to the data module via `test_trans`.
# NOTE(review): RandomRotate90 picks a random number of quarter turns even with
# p=1, so that view is not reproducible between runs -- confirm this is intended.
tta_trans = [
    None,
    {'HorizontalFlip': {'p': 1}},
    {'VerticalFlip': {'p': 1}},
    {'Transpose': {'p': 1}},
    {'RandomRotate90': {'p': 1}},
]

dm = AllDataModule(batch_size=512, num_workers=10, pin_memory=True)
dm.setup()

all_probas = []    # one (N, D) tensor of predictions per TTA view
observations = []  # observation ids, in the order the test loader yields them

# No gradients are needed for prediction; disabling autograd keeps the
# accumulated outputs from holding on to computation graphs.
with torch.no_grad():
    for trans in tta_trans:
        # Rebuild the datasets so the test split uses the current transform.
        dm.test_trans = trans
        dm.generate_datasets()

        round_probas, round_observations = [], []
        for batch in tqdm(dm.test_dataloader()):
            preds = model.predict(batch)
            # Collect per-batch outputs and concatenate once per round instead
            # of re-copying a growing tensor on every batch.
            round_probas.append(preds.cpu())
            round_observations += batch['observation_id'].cpu().numpy().tolist()

        # Averaging across TTA rounds is only meaningful if every round
        # returns the test samples in the same order.
        assert not observations or round_observations == observations, (
            'test samples came back in a different order between TTA rounds'
        )
        observations = round_observations
        all_probas.append(torch.cat(round_probas, dim=0))

all_probas = torch.stack(all_probas, dim=0)  # (num_tta_views, N, D)

train: 1587395
val: 40080
test: 36421


100%|██████████| 72/72 [01:18<00:00,  1.09s/it]


Compose([
  HorizontalFlip(always_apply=False, p=1),
], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={'nir': 'image', 'alt': 'image', 'lc': 'image'})


100%|██████████| 72/72 [01:13<00:00,  1.02s/it]


Compose([
  VerticalFlip(always_apply=False, p=1),
], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={'nir': 'image', 'alt': 'image', 'lc': 'image'})


100%|██████████| 72/72 [01:15<00:00,  1.06s/it]


Compose([
  Transpose(always_apply=False, p=1),
], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={'nir': 'image', 'alt': 'image', 'lc': 'image'})


100%|██████████| 72/72 [01:15<00:00,  1.05s/it]


Compose([
  RandomRotate90(always_apply=False, p=1),
], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={'nir': 'image', 'alt': 'image', 'lc': 'image'})


100%|██████████| 72/72 [01:15<00:00,  1.05s/it]


In [7]:
# Collapse the TTA axis by averaging the per-view predictions -> (N, D).
mean_probas = all_probas.mean(0)
# Keep the 30 highest-scoring class indices per sample (sorted, best first).
values, ixs = mean_probas.topk(k=30, dim=-1)
# One space-separated string of class indices per sample.
final_labels = [' '.join(map(str, row)) for row in ixs.tolist()]

In [10]:
import pandas as pd 

# Pair each observation id with its space-separated top-k label string.
submission = pd.DataFrame(dict(Id=observations, Predicted=final_labels))
# Write without the positional index so the file has only the two columns.
submission.to_csv('submission.csv', index=False)
# Show a few random rows as a sanity check.
submission.sample(n=5)

Unnamed: 0,Id,Predicted
33046,21703844,125 2821 2902 3495 2619 4950 760 2822 2196 553...
12102,10153678,21 338 216 436 437 694 496 213 932 510 278 270...
13495,10288281,476 78 208 206 1949 170 495 645 218 514 195 15...
34644,21877720,6362 4064 5444 2902 1381 3495 5145 5068 5317 4...
36282,22056159,5228 5207 5045 5030 5039 5218 3072 5434 5270 6...


In [9]:
# Validate the generated submission against the provided sample file.
sample_submission = pd.read_csv('data/sample_submission.csv')

assert len(sample_submission) == len(submission), (
    f'row count mismatch: sample has {len(sample_submission)} rows, '
    f'submission has {len(submission)}'
)
# Matching row counts alone would not catch wrong or duplicated ids, so also
# require the two files to cover exactly the same set of Ids.
assert set(submission['Id']) == set(sample_submission['Id']), (
    'submission Ids differ from sample_submission Ids'
)

sample_submission.sample(10)

Unnamed: 0,Id,Predicted
3309,10157629,0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18...
1689,10080870,0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18...
3055,10146125,0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18...
32874,21685132,0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18...
8204,10388907,0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18...
30101,21390964,0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18...
10826,10519692,0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18...
4271,10202332,0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18...
25904,20942337,0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18...
19837,20280142,0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18...
