# TTA (Test-Time Augmentation)

In [None]:
from modules.utils import load_yaml, rle_encode
from modules.model import get_smp_model
from modules.dataset import smpDataset
from modules.augmentation import *

import os
import random
from tqdm import tqdm

import torch
import numpy as np
import pandas as pd
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [None]:
# Load the inference configuration (model settings, checkpoint path, seed)
# from <prj_dir>/config/predict_smp_512.yaml.
prj_dir = './'
config_path = os.path.join(prj_dir, 'config', 'predict_smp_512.yaml')
config = load_yaml(config_path)

In [None]:
# Seed every random number generator used here (CUDA, torch, NumPy, stdlib)
# with the configured seed, in the same order as before.
for seed_fn in (torch.cuda.manual_seed, torch.manual_seed, np.random.seed, random.seed):
    seed_fn(config['seed'])

# Prefer deterministic cuDNN kernels over auto-tuned ones for reproducible runs.
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Build the list of test image paths: only the file name from test.csv's
# img_path column is kept and re-rooted under ./data/test_img.
df = pd.read_csv('./data/test.csv')

images = [f"./data/test_img/{os.path.basename(img)}" for img in df["img_path"]]

In [None]:
# Inference dataset over the test image paths: there are no masks, and no
# transform is attached here because normalization is applied per TTA view.
test_dataset = smpDataset(
    images=images,
    masks=None,
    transform=None,
    infer=True,
)
print('test len:', len(test_dataset))

In [None]:
# Build the segmentation network described by config['model'] and load the
# trained weights stored under the checkpoint's 'model' key.
model_cfg = config['model']

model = get_smp_model(name=model_cfg['architecture'])
model = model(
    encoder_name=model_cfg['encoder'],
    encoder_weights=model_cfg['encoder_weight'],
    in_channels=model_cfg['in_channel'],
    classes=model_cfg['n_classes'],
)
model.to(device)

# map_location remaps tensors onto the selected device, so a checkpoint saved
# on a GPU can still be loaded when `device` fell back to the CPU.
weights = torch.load(model_cfg['pretrained'], map_location=device)
model.load_state_dict(weights['model'])

In [None]:
class TTA:
    """Dihedral test-time augmentation for a single image.

    ``process`` builds a batch of six views of one image (the original plus
    five flipped/rotated copies), each passed through ``transform``.
    ``unprocess`` maps the six per-view predictions back onto the original
    orientation so they can be combined.

    The 90-degree views swap height and width, so the six views can only be
    batched together when the image is square.
    """

    def __init__(self, transform):
        # Pipeline applied to every view after its geometric op; it is called
        # albumentations-style as ``transform(image=...)['image']``.
        self.transform = transform

        # (forward, inverse) pairs acting on the two leading (spatial) axes.
        # np.flip / np.rot90 are exact index permutations, so every inverse
        # undoes its forward op pixel-for-pixel with no interpolation or
        # border filling. Keeping each pair together means a view can never
        # be matched with the wrong inverse in unprocess.
        self.ops = [
            (lambda x: np.flip(x, 0), lambda x: np.flip(x, 0)),            # vertical flip
            (lambda x: np.flip(x, 1), lambda x: np.flip(x, 1)),            # horizontal flip
            (lambda x: np.flip(x, (0, 1)), lambda x: np.flip(x, (0, 1))),  # both flips
            (lambda x: np.rot90(x, 1), lambda x: np.rot90(x, -1)),         # rotate 90 deg CCW
            (lambda x: np.rot90(x, -1), lambda x: np.rot90(x, 1)),         # rotate 90 deg CW
        ]

    def process(self, image):
        """Return a float32 batch of the identity view plus every augmented view.

        Args:
            image: array whose first two axes are (H, W), e.g. HxW or HxWxC,
                in whatever form ``self.transform`` accepts.

        Returns:
            torch.Tensor stacking the transformed views along a new first axis,
            identity first, then one view per entry of ``self.ops`` in order.

        Raises:
            ValueError: if the image is not square, since the rotated views
                would not share a shape with the others and cannot be batched.
        """
        height, width = image.shape[:2]
        if height != width:
            raise ValueError(f'TTA needs a square image, got {height}x{width}')

        views = [image] + [forward(image) for forward, _ in self.ops]
        # np.flip / np.rot90 return strided views; make them contiguous so that
        # torch.from_numpy-based converters (such as ToTensorV2) accept them.
        tensors = [
            torch.as_tensor(self.transform(image=np.ascontiguousarray(view))['image'])
            for view in views
        ]
        return torch.stack(tensors).to(torch.float32)

    def unprocess(self, images):
        """Undo each view's geometric op so all predictions align with the input.

        Args:
            images: sequence (e.g. an array of shape (6, H, W)) of per-view
                predictions, in the order produced by ``process``.

        Returns:
            list of arrays, one per view, in the original image orientation.

        Raises:
            ValueError: if the number of predictions does not match the
                number of views produced by ``process``.
        """
        expected = 1 + len(self.ops)
        if len(images) != expected:
            raise ValueError(f'expected {expected} predictions, got {len(images)}')

        results = [images[0]]  # identity view needs no inverse
        for prediction, (_, inverse) in zip(images[1:], self.ops):
            results.append(np.ascontiguousarray(inverse(prediction)))
        return results

# Per-view preprocessing applied inside TTA: normalize, then convert to a tensor.
transform = A.Compose([A.Normalize(), ToTensorV2()])

tta = TTA(transform=transform)

In [None]:
# Predict every test image with TTA: run the model on all views at once,
# binarize each view's prediction, map the masks back to the original
# orientation, and keep a pixel when at least half of the views mark it.
result = []

with torch.no_grad():
    model.eval()
    for image, _filename in tqdm(test_dataset):
        images = tta.process(image).to(device)

        predicts = model(images)

        # squeeze(1) drops only the channel axis, keeping the view axis intact
        # (assumes a single-channel output, n_classes == 1 -- TODO confirm).
        seg_prob = torch.sigmoid(predicts).squeeze(1).cpu().numpy()
        seg = (seg_prob > 0.5).astype(np.uint8)

        tta_seg = tta.unprocess(seg)

        # Majority vote across however many views TTA produced; a tie
        # (exactly half the views) counts as foreground.
        votes = np.sum(tta_seg, axis=0)
        tta_image = (votes >= len(tta_seg) / 2).astype(np.uint8)

        mask_rle = rle_encode(tta_image)
        # An empty mask (empty RLE string) is recorded as -1.
        result.append(-1 if mask_rle == '' else mask_rle)

# Submission

In [None]:
# Start from the sample submission and fill its mask_rle column with the
# predictions, one per row in test order.
submit = pd.read_csv('./data/sample_submission.csv').assign(mask_rle=result)

In [None]:
submit.to_csv(path_or_buf='./submit25.csv', index=False)  # no index column in the output

In [None]:
# from dacon_submit_api import dacon_submit_api
#
# Upload the CSV written above. Never hard-code the API token in source;
# read it from the environment instead.
# submission_response = dacon_submit_api.post_submission_file(
#     './submit25.csv',
#     os.environ['DACON_API_TOKEN'],
#     '236092',
#     'ADED',
#     ''
# )