In [None]:
!pip install transformers

In [None]:
!pip install lightning

In [None]:
!pip install -U git+https://github.com/qubvel/segmentation_models.pytorch

In [None]:
pip install wandb

In [None]:
import os
import cv2
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision import models
from datasets import load_metric
from torchmetrics.functional import dice
import torch.optim as optim
import lightning as pl
import segmentation_models_pytorch as smp

from transformers import SegformerForSemanticSegmentation
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import KFold
# Global fallback device: CUDA when a GPU is visible, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

### Utils

In [None]:
# RLE decoding function
def rle_decode(mask_rle, shape):
    """Decode a run-length-encoded mask into a binary mask.

    Args:
        mask_rle: Space-separated ``"start length start length ..."`` string with
            1-based start positions over the row-major flattened image. An empty
            string, ``-1``/``'-1'`` (the marker this notebook writes for masks with
            no foreground) or a non-string missing value (e.g. NaN) decodes to an
            all-zero mask.
        shape: ``(height, width)`` of the output mask.

    Returns:
        ``np.uint8`` array of ``shape`` with 1 on encoded pixels and 0 elsewhere.
    """
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    # pandas yields NaN for blank cells and an int column when every value is -1;
    # treat those "no mask" markers as empty instead of crashing on .split().
    if not isinstance(mask_rle, str) or mask_rle.strip() in ('', '-1'):
        return img.reshape(shape)
    s = mask_rle.split()
    starts = np.asarray(s[0::2], dtype=int) - 1  # RLE starts are 1-based
    lengths = np.asarray(s[1::2], dtype=int)
    for lo, length in zip(starts, lengths):
        img[lo:lo + length] = 1
    return img.reshape(shape)

# RLE encoding function
def rle_encode(mask):
    """Run-length encode a binary mask.

    The mask is flattened in row-major order; the result is a string of
    space-separated ``start length`` pairs with 1-based starts. An all-zero
    mask encodes to ``''``.
    """
    # Zero-pad both ends so every run of ones has a rising and a falling edge.
    padded = np.concatenate([[0], mask.flatten(), [0]])
    # 1-based positions where the value changes: start, end(exclusive), start, ...
    edges = (padded[1:] != padded[:-1]).nonzero()[0] + 1
    # Replace each exclusive end with the length of the run it closes.
    edges[1::2] -= edges[::2]
    return ' '.join(map(str, edges))

### Custom Dataset

In [None]:
class SatelliteDataset(Dataset):
    """Satellite-image dataset indexed by a CSV file.

    Column 1 of the CSV holds the image path; a leading ``.`` is stripped and the
    rest is appended to ``image_dir``. For training data (``infer=False``), column 2
    holds the run-length-encoded mask decoded with ``rle_decode``.

    Each item is a dict with ``'pixel_values'`` (the RGB image, after ``transform``
    if given) and, unless ``infer`` is True, ``'labels'`` (the mask as a tensor with
    an extra leading axis of size 1).
    """

    def __init__(self, csv_file, image_dir, transform=None, infer=False):
        self.data = pd.read_csv(csv_file)
        self.image_dir = image_dir
        # Called as transform(image=...) / transform(image=..., mask=...) and
        # indexed with 'image' / 'mask' (albumentations-style).
        self.transform = transform
        self.infer = infer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_filename = self.data.iloc[idx, 1].lstrip('.')
        # Plain concatenation: after lstrip('.') the CSV paths presumably start
        # with '/', which os.path.join would treat as absolute and drop image_dir.
        img_path = self.image_dir + img_filename
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.infer:
            if self.transform:
                image = self.transform(image=image)['image']
            return {'pixel_values': image}

        mask_rle = self.data.iloc[idx, 2]
        mask = rle_decode(mask_rle, (image.shape[0], image.shape[1]))

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Without a tensor-producing transform the mask is still a NumPy array;
        # convert it so unsqueeze works either way (no-op for tensors).
        mask = torch.unsqueeze(torch.as_tensor(mask), dim=0)

        return {'pixel_values': image, 'labels': mask}

### DataLoader

In [None]:
# Augmentation pipeline, applied in this order:
#   1. random 224x224 crop
#   2. A.Normalize() with its default mean/std
#   3. random rotation within +/-60 degrees
#   4. always (p=1) cut 3-8 rectangular holes; float sizes are fractions of the
#      image side. Holes are filled with 0 in the image and 0 in the mask. This
#      runs after Normalize, so the image fill is 0 in normalized space.
#   5. convert to torch tensors
transform = A.Compose([
    A.RandomCrop(height=224, width=224),
    A.Normalize(),
    A.Rotate(limit=60),
    A.CoarseDropout(
        min_holes=3,
        max_holes=8,
        min_height=0.125,
        max_height=0.25,
        min_width=0.125,
        max_width=0.25,
        fill_value=0,
        mask_fill_value=0,
        p=1,
    ),
    ToTensorV2(),
])

In [None]:
# Training set: every row of train.csv, augmented on the fly by `transform`.
dataset = SatelliteDataset(
    csv_file='/kaggle/input/dacon-building-data/train.csv',
    image_dir='/kaggle/input/dacon-building-data',
    transform=transform,
)
# Batch size setting: 16 images per step, reshuffled each epoch, 2 loader workers.
dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=2,
)

In [None]:
# Inference must not reuse the random training augmentations (crop, rotation,
# CoarseDropout with p=1). They would alter each test image's geometry and
# content, so the predicted mask would no longer line up with the image being
# scored. Apply only the same normalization as training and convert to a tensor.
test_transform = A.Compose([
    A.Normalize(),
    ToTensorV2(),
])
test_dataset = SatelliteDataset(csv_file='/kaggle/input/dacon-building-data/test.csv', image_dir='/kaggle/input/dacon-building-data', transform=test_transform, infer=True)
# Unshuffled so prediction i corresponds to row i of test.csv.
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2)

### Build Model

In [None]:
class SegFormerModel(pl.LightningModule):
    """LightningModule wrapping a SegFormer-style model for binary segmentation.

    Only the training loop is implemented. There is no ``validation_step`` or
    ``test_step``, so ``val_dataloader`` is stored but currently unused.

    Args:
        train_dataloader: Served by ``train_dataloader()``. Batches are dicts with
            ``'pixel_values'`` and ``'labels'``.
        val_dataloader: Stored as ``self.val_dl``; currently unused.
        test_dataloader: Served by ``test_dataloader()``.
        metrics_interval: ``training_step`` logs and prints every this many batches.
        model: Model called as ``model(pixel_values=..., return_dict=True)`` and
            exposing ``.logits``. When None, ``get_initial_model()`` is used.
    """

    def __init__(self, train_dataloader=None, val_dataloader=None, test_dataloader=None, metrics_interval=10, model=None):
        super().__init__()
        self.metrics_interval = metrics_interval
        self.train_dl = train_dataloader
        self.val_dl = val_dataloader
        self.test_dl = test_dataloader
        self.model = model if model is not None else get_initial_model()
        # Dice loss on raw logits (from_logits=True). An alternative tried earlier
        # was smp.losses.SoftBCEWithLogitsLoss.
        self.loss_module = smp.losses.DiceLoss(mode="binary", smooth=1.0, from_logits=True)
        # Per-batch IoU values collected since the last metric log.
        self.train_step_ious = []
        # Reserved for a validation/test loop; nothing appends to these yet.
        self.validation_step_ious = []
        self.validation_step_outputs = []
        self.test_step_outputs = []
        # DataLoaders and the nn.Module are not hyperparameters. Saving them would
        # pickle them into every checkpoint and pass them to the logger.
        self.save_hyperparameters(ignore=['train_dataloader', 'val_dataloader', 'test_dataloader', 'model'])

    def forward(self, images, masks=None):
        """Run the wrapped model on ``images``; ``masks`` is accepted but unused."""
        return self.model(pixel_values=images)

    def training_step(self, batch, batch_idx):
        """Compute the Dice loss for one batch and periodically log loss and IoU.

        Every ``metrics_interval`` batches this logs ``train_loss`` (this batch) and
        ``train_mean_iou`` (mean IoU over batches since the previous log), then
        resets the IoU buffer.
        """
        # Lightning has already moved the batch to self.device. Labels arrive as
        # (B, 1, H, W) and are used as-is: the previous squeeze()/unsqueeze(1)
        # pair turned a batch of one into an (H, 1, W) tensor.
        masks = batch['labels'].long()
        images = batch['pixel_values'].float()

        outputs = self.model(pixel_values=images, return_dict=True)

        # Resize logits to the mask's spatial size before loss and metrics.
        upsampled_logits = nn.functional.interpolate(
            outputs.logits,
            size=masks.shape[-2:],
            mode="bilinear",
            align_corners=False,
        ).contiguous()

        loss = self.loss_module(upsampled_logits, masks)
        tp, fp, fn, tn = smp.metrics.get_stats((upsampled_logits.sigmoid() > 0.5).long(), masks, mode='binary')
        iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")
        self.train_step_ious.append(iou)

        if batch_idx % self.metrics_interval == 0:
            mean_iou = torch.stack(self.train_step_ious).mean()
            # Reset so the buffer stays bounded and each log covers one window.
            self.train_step_ious.clear()
            self.log('train_loss', loss)
            self.log('train_mean_iou', mean_iou)

            print(f"Training loss: {loss:.5f}")
            print("\n-----------------------")

        return {'loss': loss}

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return the model output for a batch; logits are not upsampled here."""
        images = batch['pixel_values'].float()
        return self.model(images, return_dict=True)

    def configure_optimizers(self):
        """Adam over trainable parameters with a per-step cosine LR scheduler."""
        # TODO: the learning rate still needs tuning.
        # NOTE(review): the scheduler steps every optimizer step with T_max=100, so
        # the LR reaches eta_min after 100 steps and CosineAnnealingLR then rises
        # again (periodic schedule). Confirm this cyclic behaviour is intended.
        optimizer = torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=1e-03, eps=1e-07)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-06, last_epoch=-1)
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "step"}, "monitor": "train_loss"}

    def train_dataloader(self):
        return self.train_dl

    def test_dataloader(self):
        return self.test_dl

In [None]:
def get_initial_model():
    """Load the pretrained ``nvidia/mit-b3`` SegFormer with a single output label.

    ``num_labels=1`` requests one logit channel (binary segmentation).
    ``ignore_mismatched_sizes=True`` lets weights whose shapes differ from the
    checkpoint be re-initialized instead of raising.
    """
    pretrained_name = "nvidia/mit-b3"
    model = SegformerForSemanticSegmentation.from_pretrained(
        pretrained_name,
        num_labels=1,
        ignore_mismatched_sizes=True,
        return_dict=True,
    )
    return model

In [None]:
import wandb

# Loss monitoring is sent to Weights & Biases (wandb.ai).
# Never hardcode the API key in the notebook: anyone with a copy of it could use
# the account. Provide the key through the WANDB_API_KEY environment variable
# (e.g. set from Kaggle Secrets). When it is unset, key=None lets wandb.login
# fall back to its own lookup / interactive prompt.
wandb_api = os.environ.get('WANDB_API_KEY')
wandb.login(key=wandb_api)

In [None]:
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import LearningRateMonitor


# Early-stopping settings (edit here to change when training stops early):
#   min_delta -> minimum decrease of the monitored loss that counts as improvement
#   patience  -> number of checks without such an improvement before stopping
# NOTE(review): train_loss is a single-batch value, logged only every
# metrics_interval batches, so early stopping / checkpoint selection is noisy.
early_stop_callback = EarlyStopping(
    monitor="train_loss",
    mode="min",
    min_delta=0.01,
    patience=3,
    verbose=False,
)

# Keep only the single best checkpoint by train_loss.
checkpoint_callback = ModelCheckpoint(
    dirpath='/kaggle/working/checkpoint',
    monitor="train_loss",
    save_top_k=1,
)

# Log to the W&B project 'seg1'; log_model='all' asks the logger to upload checkpoints.
wandb_logger = WandbLogger(project='seg1', log_model='all')

# Record the learning rate at every optimizer step.
lr_monitor_callback = LearningRateMonitor(logging_interval='step')

In [None]:
# LightningModule serving the training DataLoader; logs metrics every 5 batches.
segformer = SegFormerModel(train_dataloader=dataloader, metrics_interval=5)

# Single-GPU run for at most 6 epochs, with early stopping, checkpointing,
# LR monitoring and W&B logging.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=6,
    callbacks=[early_stop_callback, checkpoint_callback, lr_monitor_callback],
    logger=wandb_logger,
)

In [None]:
# Train; the module supplies its own train DataLoader.
trainer.fit(model=segformer)

### Save & Load Model

In [None]:
# Display the path of the best checkpoint kept by checkpoint_callback.
checkpoint_callback.best_model_path

In [None]:
# load_from_checkpoint is a classmethod that returns a new model, so call it on
# the class. Recent Lightning versions raise TypeError when it is called on an
# instance such as `segformer`.
checkpoint_model = SegFormerModel.load_from_checkpoint(checkpoint_callback.best_model_path, map_location=torch.device('cpu'))

### Inference

In [None]:
# To load a specific checkpoint file instead of the one just trained:
#checkpoint_model = SegFormerModel.load_from_checkpoint('/kaggle/input/checkpoint/epoch0-step447 (1).ckpt', map_location=torch.device('cpu'))

# Switch to inference mode; train(False) is exactly what eval() does.
checkpoint_model.train(False)

In [None]:
# One predict_step output per test batch, in DataLoader order.
outputs = trainer.predict(model=checkpoint_model, dataloaders=test_dataloader)

In [None]:
result = []

# trainer.predict returned one output per batch of the unshuffled
# test_dataloader, so zipping them keeps predictions aligned with test.csv rows.
assert len(outputs) == len(test_dataloader), "predictions and test batches are misaligned"

for data, output in tqdm(zip(test_dataloader, outputs), total=len(outputs)):
    image = data['pixel_values']
    # Upsample logits to the input image size before thresholding.
    upsampled_logits = nn.functional.interpolate(
        output.logits,
        size=image.shape[-2:],
        mode="bilinear",
        align_corners=False,
    ).cpu()
    # (B, 1, H, W) logits -> (B, H, W) boolean masks
    predicted = (torch.sigmoid(upsampled_logits) > 0.5).squeeze(1).numpy()
    # Encode each image on its own so batch_size > 1 still yields one RLE per row.
    for mask in predicted:
        mask_rle = rle_encode(mask)
        # -1 marks an image with no predicted building pixels.
        result.append(-1 if mask_rle == '' else mask_rle)

### Submission

In [None]:
# Fill the sample submission's mask_rle column with one prediction per test image.
submit = pd.read_csv('/kaggle/input/dacon-building-data/sample_submission.csv').assign(mask_rle=result)

In [None]:
# Write the submission file without the DataFrame index column.
submit.to_csv(path_or_buf='./submit.csv', index=False)