<a href="https://www.kaggle.com/code/dlbkvv/exam-ships-segmentation?scriptVersionId=185639066" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [None]:
# Install segmentation_models_pytorch (provides Unet, DiceLoss and encoder preprocessing used below).
!pip install -q segmentation-models-pytorch

In [None]:
# Install torchsummary (used to print a layer-by-layer model summary).
!pip install -q torchsummary

In [None]:
import numpy as np 
import pandas as pd 
import torch
import os
from torchvision import transforms
from sklearn.model_selection import train_test_split
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import get_preprocessing_fn
from torchsummary import summary
from segmentation_models_pytorch.losses import DiceLoss
from segmentation_models_pytorch import Unet
import torchmetrics

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Show which device string ('cuda' or 'cpu') the rest of the notebook will use.
device

In [None]:
# Competition data: one CSV row per ship annotation plus a folder of images.
data_root = '/kaggle/input/itstep-exam2'
img_dir = os.path.join(data_root, 'ship_images')
df = pd.read_csv(os.path.join(data_root, 'ship_segmentations.csv'))

In [None]:
# Peek at the first rows of the segmentation table.
df.head()

In [None]:
# Summary statistics of the segmentation table.
df.describe()

In [None]:
# Column dtypes and non-null counts; a missing EncodedPixels marks an image without ships.
df.info()

In [None]:
# Non-null EncodedPixels per image = number of ships on that image (0 when it has none).
count_ships_df = df.groupby('ImageId').agg('count')
count_ships_df

In [None]:
# Histogram of ships per image: for every ship count 0..max, how many images have exactly
# that many ships. A single value_counts pass replaces one boolean filter per count;
# reindex fills counts that never occur with 0.
max_ships_img = count_ships_df['EncodedPixels'].max()
count_img_to_ships = (
    count_ships_df['EncodedPixels']
    .value_counts()
    .reindex(range(max_ships_img + 1), fill_value=0)
    .rename_axis('ShipCount')
    .reset_index(name='ImageCount')
)
count_img_to_ships

In [None]:
# Images containing at least one ship. Derived from the histogram (every row except
# ShipCount == 0) instead of subtracting a hard-coded number of ship-free images,
# so it stays correct for any subset of the data.
total_img_with_ships = count_img_to_ships.loc[count_img_to_ships['ShipCount'] > 0, 'ImageCount'].sum()
total_img_with_ships

In [None]:

# Number of images (y) per ships-per-image count (x).
plt.bar(count_img_to_ships['ShipCount'], count_img_to_ships['ImageCount'], edgecolor='black')
plt.xlabel('ships')
plt.ylabel('img_count')
plt.grid(True)
plt.show()

In [None]:

# Same histogram (images per ships-per-image count), with the y-axis capped at 2000 so
# the smaller bars stay readable next to the largest one.
plt.bar(count_img_to_ships['ShipCount'], count_img_to_ships['ImageCount'], edgecolor='black')
plt.xlabel('ships')
plt.ylabel('img_count')
plt.ylim(0, 2000)
plt.grid(True)
plt.show()

In [None]:
# Keep every ship annotation but only a sample of the ship-free images
# (rows whose EncodedPixels is missing).
wout_ships = df[df['EncodedPixels'].isna()]
with_ships = df[df['EncodedPixels'].notna()]

# Seeded so the subset (and everything derived from it) is reproducible; capped so
# sample() does not raise when fewer than 500 ship-free rows exist.
reduced_wout_ships = wout_ships.sample(min(500, len(wout_ships)), random_state=42)

balanced_df = pd.concat([with_ships, reduced_wout_ships])
balanced_df

In [None]:
# Turn missing RLEs into the literal string 'nan', the "no ships" sentinel compared against
# later. fillna makes this explicit: with pandas' newer string dtype, astype(str) keeps
# missing values missing instead of producing 'nan'.
balanced_df['EncodedPixels'] = balanced_df['EncodedPixels'].fillna('nan').astype(str)

In [None]:
# One row per image; EncodedPixels becomes the list of that image's RLE strings.
balanced_grouped = (
    balanced_df
    .groupby('ImageId')['EncodedPixels']
    .apply(lambda rles: rles.tolist())
    .reset_index()
)
balanced_grouped

In [None]:
# Split the per-image table into ship-free images (EncodedPixels == ['nan']) and
# images with ships. A boolean mask keeps the same rows, in the same order, as the
# former groupby + filter.
is_empty = balanced_grouped['EncodedPixels'].apply(lambda rles: rles == ['nan']).astype(bool)

nan_group = balanced_grouped[is_empty]
not_nan_group = balanced_grouped[~is_empty]

# At most 1500 images with ships; seeded for reproducibility, and capped because
# sample() raises when asked for more rows than exist.
reduced_not_nan_group = not_nan_group.sample(min(1500, len(not_nan_group)), random_state=42)

rebalanced_grouped = pd.concat([reduced_not_nan_group, nan_group])
rebalanced_grouped

In [None]:
class ShipDataset(Dataset):
    """Ship images paired with binary segmentation masks decoded from RLE strings.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per image with columns ``ImageId`` (file name inside ``image_dir``)
        and ``EncodedPixels`` (list of RLE strings; ``'nan'`` marks "no ships").
    image_dir : str
        Directory containing the image files.
    image_shape : tuple of int, optional
        Mask shape ``(height, width)``. Default ``(768, 768)``.
    transform : callable, optional
        Applied to the PIL image first.
    preprocessing_fn : callable, optional
        Applied to ``np.array(image)`` after ``transform``.
    """

    def __init__(self, df, image_dir, image_shape=(768,768), transform=None, preprocessing_fn=None):
        self.df = df
        self.image_dir = image_dir
        self.shape = image_shape
        self.transform = transform
        self.preprocessing_fn = preprocessing_fn

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for row ``idx``.

        ``mask`` is a ``torch.uint8`` tensor of shape ``image_shape``. ``image`` is an
        RGB PIL image unless ``transform`` / ``preprocessing_fn`` convert it.
        """
        row = self.df.iloc[idx]
        image_path = os.path.join(self.image_dir, row['ImageId'])
        # Decode inside ``with`` so the file handle is released; convert('RGB') returns a
        # fully loaded copy and guarantees a 3-channel image.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        mask = self.combine_rle_masks(row['EncodedPixels'], self.shape)

        if self.transform:
            image = self.transform(image)
        if self.preprocessing_fn:
            image = self.preprocessing_fn(np.array(image))

        return image, mask

    def rle_to_mask(self, rle, shape):
        """
        Decode one RLE string into a binary mask.

        Runs are ``"start1 length1 start2 length2 ..."`` with 1-based starts counted
        in column-major order (down each column, then left to right).

        Parameters:
        rle (str): RLE string; 'nan', an empty string, None or a float NaN mean "no object"
        shape (tuple): mask shape (height, width)

        Returns:
        numpy.ndarray: uint8 mask of shape (height, width), 1 on object pixels
        """
        height, width = shape
        flat = np.zeros(height * width, dtype=np.uint8)

        # Missing annotations may arrive as the 'nan' sentinel or as a real NaN/None.
        if not isinstance(rle, str) or rle == 'nan' or not rle.strip():
            return flat.reshape(shape)

        nums = [int(tok) for tok in rle.split()]
        for start, length in zip(nums[0::2], nums[1::2]):
            flat[start - 1:start - 1 + length] = 1

        # Column-major layout: fill as (width, height), then transpose to (height, width).
        # (reshape(shape).T was only correct for square masks.)
        return flat.reshape(width, height).T

    def combine_rle_masks(self, rles, shape):
        """
        Merge several RLE masks into one (pixel-wise maximum).

        Parameters:
        rles (list of str): RLE strings; a single string is treated as a one-element list
        shape (tuple): mask shape (height, width)

        Returns:
        torch.Tensor: combined uint8 mask of shape (height, width)
        """
        # A bare string (or scalar NaN) would otherwise be iterated character by character / fail.
        if isinstance(rles, str) or not hasattr(rles, '__iter__'):
            rles = [rles]

        combined_mask = np.zeros(shape, dtype=np.uint8)
        for rle in rles:
            np.maximum(combined_mask, self.rle_to_mask(rle, shape), out=combined_mask)

        return torch.tensor(combined_mask, dtype=torch.uint8)
    
   

In [None]:
# No transform / preprocessing: samples stay as viewable images for the preview below.
visualize_dataset = ShipDataset(image_dir=img_dir, df=rebalanced_grouped)

In [None]:
# Split on image ids so no image lands in both train and test.
unique_image_ids = pd.unique(rebalanced_grouped['ImageId'])

In [None]:
train_ids, test_ids = train_test_split(unique_image_ids, train_size=0.8, random_state=42)

train_df = rebalanced_grouped[rebalanced_grouped['ImageId'].isin(train_ids)]
test_df = rebalanced_grouped[rebalanced_grouped['ImageId'].isin(test_ids)]

# Regroup per ImageId. EncodedPixels is already a list per row, so each cell becomes a
# nested list ([[rle, ...]]), which the next cell flattens.
train_grouped, test_grouped = (
    part.groupby('ImageId')['EncodedPixels'].apply(lambda s: s.tolist()).reset_index()
    for part in (train_df, test_df)
)

train_grouped.columns = ['ImageId', 'EncodedPixels']
test_grouped.columns = ['ImageId', 'EncodedPixels']

In [None]:
# Remove the unnecessary extra nesting level: [[rle, ...]] -> [rle, ...]
def _flatten_once(nested):
    return [rle for inner in nested for rle in inner]

for _frame in (train_grouped, test_grouped):
    _frame['EncodedPixels'] = _frame['EncodedPixels'].apply(_flatten_once)

In [None]:
# Inspect one flattened list of RLE strings from the training split.
train_grouped['EncodedPixels'].iloc[6]

In [None]:
# Normalisation matching the ImageNet-pretrained resnet18 encoder.
preprocess = get_preprocessing_fn('resnet18', pretrained='imagenet')

train_dataset, test_dataset = (
    ShipDataset(df=split, image_dir=img_dir, preprocessing_fn=preprocess)
    for split in (train_grouped, test_grouped)
)


In [None]:
# Preview the first 10 samples: image on the left, decoded ship mask on the right.
for sample_idx in range(10):
    img, msk = visualize_dataset[sample_idx]
    if isinstance(img, torch.Tensor):
        img = img.permute(1, 2, 0).numpy()
    if isinstance(msk, torch.Tensor):
        msk = msk.numpy()

    plt.figure(figsize=(10, 5))
    panels = [(img, "Image", None), (msk, "Mask", 'gray')]
    for pos, (data, title, cmap) in enumerate(panels, start=1):
        plt.subplot(1, 2, pos)
        plt.imshow(data, cmap=cmap)
        plt.title(title)
        plt.axis("off")
    plt.show()


In [None]:
batch_size = 16
# Shuffle only the training split; the evaluation order stays fixed.
train_loader, test_loader = (
    DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
    DataLoader(test_dataset, batch_size=batch_size),
)

In [None]:
# U-Net with an ImageNet-pretrained resnet18 encoder, one output channel and a sigmoid,
# so the model returns per-pixel probabilities.
unet_config = dict(
    encoder_name='resnet18',
    encoder_depth=5,
    encoder_weights='imagenet',
    decoder_channels=(256, 128, 64, 32, 16),
    decoder_use_batchnorm=True,
    decoder_attention_type=None,
    in_channels=3,
    classes=1,
    activation='sigmoid',
)
model = Unet(**unet_config).to(device)

# Freeze every encoder parameter; everything outside model.encoder stays trainable.
# NOTE(review): frozen encoder BatchNorm layers still update running stats in train mode.
model.encoder.requires_grad_(False)

In [None]:
# torchsummary defaults to device='cuda'; pass the active device so this also runs on CPU.
summary(model, input_size=(3, 768, 768), device=device)

In [None]:
# @title train function
import time


def _run_epoch(model, loader, loss_fn, recall, device, optimizer=None):
    """Run one pass over ``loader``; return ``(mean_loss, recall)``.

    With an ``optimizer`` the model is trained (train mode, backward, step);
    without one it is evaluated in eval mode with gradients disabled.

    ``recall`` is a torchmetrics metric that is reset here, updated with every batch
    and computed once at the end. The result is therefore the pixel recall over the
    whole pass, not a batch-size-weighted mean of per-batch recalls (where a batch
    without positive pixels would count as 0).
    """
    training = optimizer is not None
    model.train(training)
    recall.reset()
    total_loss = 0.0

    with torch.set_grad_enabled(training):
        for images, masks in loader:
            # assumes the loader yields channels-last images (N, H, W, C) -> (N, C, H, W)
            images = images.permute(0, 3, 1, 2).to(torch.float32).to(device)
            masks = masks.to(device).unsqueeze(1)

            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = loss_fn(outputs, masks)
            if training:
                loss.backward()
                optimizer.step()

            total_loss += loss.item() * images.size(0)
            recall.update(outputs.detach(), masks.int())

    return total_loss / len(loader.dataset), recall.compute().item()


def train_segmentation(model, optimizer, loss_fn, train_dl, val_dl, epochs=20, device='cpu'):
    """
    Train a segmentation model, evaluating loss and pixel recall after every epoch.

    Parameters
    ----------
    model : nn.Module
        The segmentation model to train.
    optimizer : torch.optim.Optimizer
        The optimizer to use for training.
    loss_fn : nn.Module
        Loss applied to ``(model output, mask of shape (N, 1, H, W))``, e.g. DiceLoss.
    train_dl : DataLoader
        DataLoader for the training dataset.
    val_dl : DataLoader
        DataLoader for the validation dataset.
    epochs : int, optional
        Number of epochs to train the model. Default is 20.
    device : str, optional
        The device to use for training ('cpu' or 'cuda'). Default is 'cpu'.

    Returns
    -------
    dict
        ``train_loss``, ``val_loss``, ``train_recall``, ``val_recall``: lists with one
        value per epoch.
    """
    # Binary pixel recall; predictions are binarised at 0.5.
    recall = torchmetrics.Recall(task='binary', threshold=0.5).to(device)

    history = {
        'train_loss': [],
        'val_loss': [],
        'train_recall': [],
        'val_recall': []
    }

    start_time = time.perf_counter()

    for epoch in range(epochs):
        epoch_start_time = time.perf_counter()

        train_loss, train_recall = _run_epoch(model, train_dl, loss_fn, recall, device, optimizer)
        val_loss, val_recall = _run_epoch(model, val_dl, loss_fn, recall, device)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_recall'].append(train_recall)
        history['val_recall'].append(val_recall)

        epoch_time = time.perf_counter() - epoch_start_time
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Train Recall: {train_recall:.4f}, Val Recall: {val_recall:.4f}, Epoch Time: {epoch_time:.2f}")

    total_time = time.perf_counter() - start_time
    print(f"Training completed in {total_time:.2f} sec")

    return history

In [None]:
# The model ends in a sigmoid, so DiceLoss receives probabilities (from_logits=False).
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
loss_fn = DiceLoss(mode='binary', from_logits=False)

In [None]:
# The test split doubles as the validation set during training.
history = train_segmentation(
    model, optimizer, loss_fn,
    train_dl=train_loader,
    val_dl=test_loader,
    epochs=15,
    device=device,
)

In [None]:
def plot_metric(history, name):
    """Plot the per-epoch train and validation curves of ``name``.

    Reads ``history['train_' + name]`` and ``history['val_' + name]``.
    """
    for split in ('train', 'val'):
        plt.plot(history[f'{split}_{name}'], label=split)
    plt.xlabel('Epoch')
    plt.ylabel(name)
    plt.title(f'{name.capitalize()} over Epochs')
    plt.legend()
    plt.show()


In [None]:
# Train vs. validation loss per epoch.
plot_metric(history, 'loss')

In [None]:
# Train vs. validation recall per epoch.
plot_metric(history, 'recall')

In [None]:
# Pickles the entire nn.Module (class reference + weights). Reloading needs full unpickling
# and compatible library versions; saving model.state_dict() would be more portable.
torch.save(model, 'ship_segmentation_model_2.pth')

In [None]:
# The checkpoint is a whole pickled nn.Module (saved with torch.save(model, ...)), so full
# unpickling is required: newer torch versions default to weights_only=True and refuse it.
# Only load files you trust, since unpickling can execute arbitrary code.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
loaded_model = torch.load('ship_segmentation_model_2.pth', map_location=device, weights_only=False)

In [None]:
# Shape of the first 5 test images. This cell previously read `images_to_visualize`, which
# is only defined in the next cell (NameError when run top to bottom); peek directly instead.
next(iter(test_loader))[0][:5].shape

In [None]:
images_batch, _ = next(iter(test_loader))

images_to_visualize = images_batch[:5]

loaded_model.eval()
with torch.no_grad():
    # channels-last batch -> (N, C, H, W) float32, as in training
    outputs = loaded_model(images_to_visualize.permute(0, 3, 1, 2).to(torch.float32).to(device))

# The loader yields encoder-normalised images. Undo that normalisation for display;
# otherwise imshow clips the out-of-range values and the colours come out wrong.
_prep_params = smp.encoders.get_preprocessing_params('resnet18', pretrained='imagenet')
_mean = np.asarray(_prep_params['mean'])
_std = np.asarray(_prep_params['std'])

for image, pred in zip(images_to_visualize, outputs):
    display_image = np.clip(image.numpy() * _std + _mean, 0.0, 1.0)

    plt.figure(figsize=(10, 5))

    plt.subplot(1, 2, 1)
    plt.imshow(display_image)
    plt.title('Original Image')
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.imshow(pred.squeeze().cpu(), cmap='gray')
    plt.title('Predicted Mask')
    plt.axis('off')

    plt.show()