In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
import pandas as pd

In [46]:
checkpoints = os.listdir('./checkpoints')
checkpoints

['baseline-val_metric=0.47755-epoch=27.ckpt', 'baseline-epoch=33.ckpt']

In [47]:
from src.module import Module
import torch 

name = "baseline-val_metric=0.47755-epoch=27.ckpt"
checkpoint = f'./checkpoints/{name}'

module = Module.load_from_checkpoint(checkpoint)

module.cpu()
module.eval()

traced = torch.jit.trace(module.model, torch.rand(10, 9, 256, 256))
traced.save(f'kaggle-dataset/{name[:-5]}.pt')

In [48]:
loaded = torch.jit.load(f'kaggle-dataset/{name[:-5]}.pt')
loaded;


In [49]:
from src.dm import DataModule
import torchmetrics
from tqdm import tqdm

dm = DataModule()
dm.setup()

loaded.eval()
loaded.cuda(1)

metric = torchmetrics.Dice().cuda(1)

with torch.no_grad():
	for batch in tqdm(dm.val_dataloader()):
		x, y = batch
		y_hat = loaded(x.cuda(1))
		metric(y_hat, y.cuda(1))

metric.compute().item()

100%|██████████| 116/116 [00:05<00:00, 22.16it/s]


0.4882103204727173

In [23]:
path = Path('/fastdata/contrails')

records = os.listdir(path / 'test')
len(records)

2

In [24]:
stats = pd.read_csv(path/'stats.csv', index_col=0)
stats

Unnamed: 0,min,max,mean,std
8,175.82391,280.14868,233.67686,4.545741
9,180.74695,279.4869,242.25447,6.057177
10,181.44263,331.13394,250.75069,7.620164
11,179.33739,332.06036,274.41205,13.668153
12,187.69131,306.21823,255.52716,8.708825
13,179.3451,338.0567,276.60184,14.446373
14,178.71164,338.6333,275.3594,14.736154
15,178.36511,333.21048,272.5641,14.367307
16,137.39153,311.97977,260.4258,11.04905


In [25]:
import torch 
import numpy as np

class Dataset(torch.utils.data.Dataset):
    def __init__(self, records):
        self.records = records

    def __len__(self):
        return len(self.records)
    
    def preprocess(self, record):
        bands = range(8,17)
        data = []
        for band in bands:
            image = np.load(path / 'test' / record / f'band_{band:02d}.npy')
            image = (image - stats.loc[band]['mean']) / stats.loc[band]['std']
            data.append(image)
        data = np.stack(data, axis=-1)
        return data[...,4,:]
    
    def __getitem__(self, ix):
        record = self.records[ix]
        image = self.preprocess(record)
        return record, torch.from_numpy(image).permute(2,0,1)

In [26]:
ds = Dataset(records)
dl = torch.utils.data.DataLoader(ds, batch_size=8, num_workers=4, pin_memory=True)

In [27]:
def rle_encode(x, fg_val=1):
    """
    Args:
        x:  numpy array of shape (height, width), 1 - mask, 0 - background
    Returns: run length encoding as list
    """

    dots = np.where(
        x.T.flatten() == fg_val)[0]  # .T sets Fortran order down-then-right
    run_lengths = []
    prev = -2
    for b in dots:
        if b > prev + 1:
            run_lengths.extend((b + 1, 0))
        run_lengths[-1] += 1
        prev = b
    return run_lengths

def list_to_string(x):
    """
    Converts list to a string representation
    Empty list returns '-'
    """
    if x: # non-empty list
        s = str(x).replace("[", "").replace("]", "").replace(",", "")
    else:
        s = '-'
    return s

In [28]:
submission = {'record_id': [],	'encoded_pixels': []}

In [29]:
loaded.eval()
loaded.cuda(1)
with torch.no_grad():
    for records, x in dl:
        y_hat = loaded(x.cuda(1))
        masks = y_hat.sigmoid().cpu().numpy() > 0.5
        masks = masks.astype(np.int32)
        for ix in range(masks.shape[0]):
            mask = masks[ix][0]
            rle = rle_encode(mask)
            record = records[ix]
            submission['record_id'].append(record)
            if len(rle) == 0: 
                submission['encoded_pixels'].append('-')
            else:
                submission['encoded_pixels'].append(' '.join(map(str,rle)))

In [30]:
submission = pd.DataFrame(submission)
submission.to_csv('submission.csv', index=False)
submission

Unnamed: 0,record_id,encoded_pixels
0,1002653297254493116,-
1,1000834164244036115,-


In [None]:
!kaggle datasets version -m "update" -p kaggle-dataset