In [1]:
from glob import glob
from src import Module, DataModule, DatasetFCI, DatasetFCIm
import torch 
import torchmetrics 
from tqdm import tqdm
import os 
from skimage import io
import shutil
import numpy as np
import pandas as pd

In [2]:
checkpoints = glob('./checkpoints/*')
checkpoints

['./checkpoints/myunet-rs50-fcim-lrs-val_metric=0.68107-epoch=14.ckpt',
 './checkpoints/unetpp-rs50-fcim-lrs-val_metric=0.67217-epoch=16.ckpt',
 './checkpoints/myunet-serx101-fcim-lrs-val_metric=0.66931-epoch=14.ckpt',
 './checkpoints/unetpp-rs50-fcim-lrs-epoch=29.ckpt',
 './checkpoints/myunet-serx101-fcim-lrs-da-val_metric=0.65259-epoch=10.ckpt',
 './checkpoints/unetpp-rs50-fcim-lrs-epoch=99.ckpt',
 './checkpoints/unetpp-rs50-fcim-lrs-focal-val_metric=0.00334-epoch=0.ckpt',
 './checkpoints/unetpp-rs50-fcim-da-epoch=29.ckpt',
 './checkpoints/myunet-serx101-fcim-lrs-da-epoch=12.ckpt',
 './checkpoints/unetpp-rs50-fcim-da-epoch=29-v1.ckpt',
 './checkpoints/unetpp-rs50-fcim-da-val_metric=0.67048-epoch=26.ckpt',
 './checkpoints/myunet-serx101-fcim-lrs-epoch=29.ckpt',
 './checkpoints/unetpp-rs50-fcim-lrs-focal-epoch=9.ckpt',
 './checkpoints/unetpp-rs50-fcim-lrs-val_metric=0.68205-epoch=15.ckpt',
 './checkpoints/myunet-rs50-fcim-lrs-epoch=29.ckpt',
 './checkpoints/unetpp-rs50-fcim-da-val_metr

In [3]:
models = [
	('UnetPlusPlus', 'timm-resnest50d', 'unetpp-rs50-fcim-lrs-val_metric=0.68205-epoch=15.ckpt'),
	('MyUnet', 'resnest50d', 'myunet-rs50-fcim-lrs-val_metric=0.68107-epoch=14.ckpt')
]

In [31]:
def eval(models, device=1, mask_loss=True):
	dm = DataModule(Dataset="DatasetFCIm")
	dm.setup()
	all_preds = []
	for model in models:
		architecture, encoder, name = model
		print(name)
		checkpoint = f'./checkpoints/{name}'
		state_dict = torch.load(checkpoint, map_location='cpu')['state_dict']
		module = Module({"encoder": encoder, "in_chans": 5, "pretrained": None, "padding": 1, "mask_loss": True, "architecture": architecture})
		module.load_state_dict(state_dict)
		module.eval()
		module.cuda(device)
		preds, gt = torch.tensor([]), torch.tensor([])
		with torch.no_grad():
			for batch in tqdm(dm.val_dataloader()):
				x, y = batch
				if mask_loss: x = x[...,:-1]
				x = x.cuda(device)
				y_hat = module(x)
				probas = torch.sigmoid(y_hat)
				preds = torch.cat([preds, probas.cpu()])
				gt = torch.cat([gt, y])
		all_preds.append(preds)
	all_preds = torch.stack(all_preds).mean(0)
	metric = torchmetrics.Dice()
	metric(all_preds, gt.long().unsqueeze(1))
	return metric.compute().item()

				

In [32]:
eval(models)

unetpp-rs50-fcim-lrs-val_metric=0.68205-epoch=15.ckpt


100%|██████████| 36/36 [00:25<00:00,  1.40it/s]


myunet-rs50-fcim-lrs-val_metric=0.68107-epoch=14.ckpt


100%|██████████| 36/36 [00:20<00:00,  1.77it/s]


0.6871915459632874

In [33]:
images = glob('data/test_satellite/*.tif')
len(images)

1426

In [34]:
image_ids = sorted([image.split('/')[-1].split('_')[0] for image in images])
len(image_ids)

1426

In [35]:
test_ds = DatasetFCIm(image_ids, mode="test", image_folder="test_satellite")

image, image_id = test_ds[0]
image.shape, image.dtype, image.max(), image.min(), image_id

((350, 350, 6), dtype('float32'), 1.0, 0.0, 'AA408972')

In [41]:
def generate_masks(models, device=1, mask_loss=True, th=0.5):
	images = glob('data/test_satellite/*.tif')
	image_ids = sorted([image.split('/')[-1].split('_')[0] for image in images])
	test_ds = DatasetFCIm(image_ids, mode="test", image_folder="test_satellite")
	test_dl = torch.utils.data.DataLoader(test_ds, batch_size=10, shuffle=False, num_workers=4)
	!rm -rf ./data/test_kelp
	os.makedirs('./data/test_kelp', exist_ok=True)
	all_preds = []
	for model in models:
		architecture, encoder, name = model
		print(name)
		checkpoint = f'./checkpoints/{name}'
		state_dict = torch.load(checkpoint, map_location='cpu')['state_dict']
		module = Module({"encoder": encoder, "in_chans": 5, "pretrained": None, "padding": 1, "mask_loss": True, "architecture": architecture})
		module.load_state_dict(state_dict)
		module.eval()
		module.cuda(device)
		preds = torch.tensor([])
		with torch.no_grad():
			for batch in tqdm(test_dl):
				x, _ = batch
				if mask_loss: x = x[...,:-1]
				x = x.cuda(device)
				y_hat = module(x)
				probas = torch.sigmoid(y_hat)
				preds = torch.cat([preds, probas.cpu()])
		all_preds.append(preds)
	all_preds = torch.stack(all_preds).mean(0)
	all_preds = all_preds > th
	all_preds = all_preds.long().squeeze(1)
	all_preds = all_preds.cpu().numpy().astype('uint8')
	for i, image_id in enumerate(image_ids):
		io.imsave(f'data/test_kelp/{image_id}_kelp.tif', all_preds[i])

In [42]:
generate_masks(models)

unetpp-rs50-fcim-lrs-val_metric=0.68205-epoch=15.ckpt


100%|██████████| 143/143 [00:34<00:00,  4.15it/s]


myunet-rs50-fcim-lrs-val_metric=0.68107-epoch=14.ckpt


100%|██████████| 143/143 [00:24<00:00,  5.91it/s]
  io.imsave(f'data/test_kelp/{image_id}_kelp.tif', all_preds[i])
  io.imsave(f'data/test_kelp/{image_id}_kelp.tif', all_preds[i])
  io.imsave(f'data/test_kelp/{image_id}_kelp.tif', all_preds[i])
  io.imsave(f'data/test_kelp/{image_id}_kelp.tif', all_preds[i])
  io.imsave(f'data/test_kelp/{image_id}_kelp.tif', all_preds[i])
  io.imsave(f'data/test_kelp/{image_id}_kelp.tif', all_preds[i])
  io.imsave(f'data/test_kelp/{image_id}_kelp.tif', all_preds[i])
  io.imsave(f'data/test_kelp/{image_id}_kelp.tif', all_preds[i])
  io.imsave(f'data/test_kelp/{image_id}_kelp.tif', all_preds[i])
  io.imsave(f'data/test_kelp/{image_id}_kelp.tif', all_preds[i])
  io.imsave(f'data/test_kelp/{image_id}_kelp.tif', all_preds[i])
  io.imsave(f'data/test_kelp/{image_id}_kelp.tif', all_preds[i])
  io.imsave(f'data/test_kelp/{image_id}_kelp.tif', all_preds[i])
  io.imsave(f'data/test_kelp/{image_id}_kelp.tif', all_preds[i])
  io.imsave(f'data/test_kelp/{image_id}_

In [43]:
generated_masks = glob('data/test_kelp/*.tif')
len(generated_masks)

1426

In [44]:
mask = io.imread(generated_masks[0])
mask.shape, mask.dtype, mask.max(), mask.min()

((350, 350), dtype('uint8'), 1, 0)

In [45]:
!rm -rf submission.zip
shutil.make_archive('submission', 'zip', 'data/test_kelp')

'/home/juan/Desktop/competis/KelpWanted/submission.zip'

In [14]:
examples = glob('data/submission_format/*.tif')
len(examples)

1426

In [15]:
mask = io.imread(examples[0])
mask.shape, mask.dtype, mask.max(), mask.min()

((350, 350), dtype('uint8'), 1, 0)