In [None]:
from modules.utils import load_yaml, rle_encode
from modules.model import get_model

from modules.dataset import CustomDataset

In [None]:
import os
from glob import glob
from tqdm import tqdm

import torch
import torch.nn as nn
import numpy as np
import albumentations as A
import pandas as pd
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader

from transformers import (
    AutoImageProcessor
)

In [None]:
# Pick the compute device first: CUDA when torch can see a GPU, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Inference settings are read from <prj_dir>/config/predict.yaml.
prj_dir = './'
config_path = os.path.join(prj_dir, 'config', 'predict.yaml')
config = load_yaml(config_path)

In [None]:
# Test-time augmentation pipeline: a single square resize to config['input_size'].
# NOTE: the misspelled name `test_transfrom` is what the dataset cell passes on,
# so it is kept as-is.
test_transfrom = A.Compose(
    [A.Resize(height=config['input_size'], width=config['input_size'])]
)

In [None]:
# Test-split index; its 'img_path' column is used to build the image file list.
# (Plain string literal: the previous f-string had no placeholders.)
df = pd.read_csv("data/test.csv")

In [None]:
# Re-root every CSV path onto the local test image folder (only the file name
# is kept), preserving the CSV row order.
images = [f"./data/test_img/{os.path.basename(p)}" for p in df["img_path"]]
    

In [None]:
# The file list comes from test.csv; an earlier variant used
# sorted(glob(f"{config['data_dir']}/*.png")) instead.
print('test len:', len(images))

processor = AutoImageProcessor.from_pretrained(config['processor'])

# Inference mode: no masks are supplied, only the image paths.
test_dataset = CustomDataset(
    processor=processor,
    images=images,
    masks=None,
    transform=test_transfrom,
    infer=True,
)

loader_kwargs = dict(
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
    shuffle=False,  # keep predictions in the same order as `images`
)
test_dataloader = DataLoader(test_dataset, **loader_kwargs)

In [None]:
# Build the architecture from its pretrained checkpoint with a single-output
# head (num_labels=1), then overwrite its parameters with the fine-tuned weights.
model = get_model(name=config['model']['name'])
model = model.from_pretrained(
    config['model']['weight'],
    num_labels=1,
    # presumably the hub checkpoint's head has a different label count; this
    # lets transformers re-initialise the mismatched layers instead of failing
    ignore_mismatched_sizes=True,
).to(device)

# map_location makes a checkpoint saved from a GPU run loadable on the CPU
# fallback chosen in the config cell (and keeps tensors on `device` otherwise).
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoints; consider weights_only=True on torch >= 1.13.
weights = torch.load(config['model']['pretrained'], map_location=device)
model.load_state_dict(weights)

print('model')

In [None]:
def postprocess(outputs, size=None):
    """Bilinearly upsample model logits to a target mask resolution.

    Args:
        outputs: model output exposing a ``logits`` tensor. Bilinear mode
            requires a 4-D tensor; presumably (batch, num_labels, h, w) —
            TODO confirm against the model.
        size: optional target ``(height, width)`` or a single int for a
            square output. Defaults to ``(config['input_size'],
            config['input_size'])`` so existing callers are unaffected.

    Returns:
        The resized logits (still pre-sigmoid).
    """
    if size is None:
        size = (config['input_size'], config['input_size'])

    predicts = nn.functional.interpolate(
        outputs.logits,
        size=size,
        mode="bilinear",
        align_corners=False,
    )
    return predicts

In [None]:
# NOTE(review): commented-out single-batch debugging scratch (it `break`s after
# one batch). The full inference cell supersedes it; it can be deleted.
# with torch.no_grad():
#     model.eval()
#     if config['amp']:
#         for idx, (batch, filenames) in enumerate(tqdm(test_dataloader)):
#             images = batch["pixel_values"].to(device)
#             outputs = model(images)
#             predicts = postprocess(outputs)

#             seg_prob = torch.sigmoid(predicts).detach().cpu().numpy().squeeze()
#             seg = (seg_prob > 0.5).astype(np.uint8)
#             break

In [None]:
# Run the model over the whole test set and RLE-encode each binarised mask.
# Empty masks are recorded as -1.
result = []

# config['amp'] now only toggles mixed precision (CUDA only). Previously it
# gated the entire loop, so amp=False produced no predictions at all.
use_amp = bool(config['amp']) and device.type == 'cuda'

model.eval()
with torch.no_grad(), torch.autocast(device_type=device.type, enabled=use_amp):
    for batch, _filenames in tqdm(test_dataloader):
        # Distinct name so the global `images` path list is not clobbered.
        pixel_values = batch["pixel_values"].to(device)
        outputs = model(pixel_values)
        predicts = postprocess(outputs)

        # Cast back to float32 so half-precision AMP logits threshold the same
        # way as fp32. Drop only the channel axis (num_labels=1): a bare
        # .squeeze() would also drop the batch axis when the last batch has size 1.
        seg_prob = torch.sigmoid(predicts.float()).cpu().numpy()[:, 0]
        seg = (seg_prob > 0.5).astype(np.uint8)

        for mask in seg:
            mask_rle = rle_encode(mask)
            result.append(-1 if mask_rle == '' else mask_rle)

# Submission

In [None]:
# Load the submission template and fill its mask_rle column positionally with
# the predictions (an RLE string per image, -1 where the mask was empty).
submit = pd.read_csv('./data/sample_submission.csv').assign(mask_rle=result)

In [None]:
# Persist the submission without the pandas index column.
submit.to_csv(path_or_buf='./submit2.csv', index=False)