In [1]:
%%writefile script/get_data.py

from warnings import filterwarnings
from pandas import read_csv, DataFrame
from sklearn.model_selection import train_test_split
from typing import Tuple
from pathlib import Path

# Blanket filter: once this module is imported, every warning category is
# suppressed for the whole interpreter, not just for the code in this file.
filterwarnings('ignore')

def data(ROOT:Path) -> Tuple[DataFrame, DataFrame]:
    """Build the table of labelled samples and split it into train/test frames.

    Reads ``train.csv`` (first column only, used as the ``id`` index) and
    ``depths.csv`` from ``ROOT``, joins the depth columns onto the training ids
    and adds two columns, ``images`` and ``masks``, holding the PNG path built
    for every id under ``ROOT/train``.

    Args:
        ROOT: Directory containing ``train.csv``, ``depths.csv`` and ``train/``.

    Returns:
        ``(trainset, testset)``: an 80/20 split of the joined frame, made
        repeatable by the fixed ``random_state``.
    """
    # Load the ids of the labelled samples and attach their depth information.
    train_df = read_csv(ROOT/'train.csv', index_col='id', usecols=[0])
    depths_df = read_csv(ROOT/'depths.csv', index_col='id')
    train_df = train_df.join(depths_df)

    # The depth rows without a training id used to be collected into a local
    # frame here, but that frame was never used or returned, so it is gone.

    # One image path and one mask path per id; the files are not opened here.
    image_dir = ROOT/'train'/'images'
    mask_dir = ROOT/'train'/'masks'
    train_df['images'] = [image_dir/f'{idx}.png' for idx in train_df.index]
    train_df['masks'] = [mask_dir/f'{idx}.png' for idx in train_df.index]

    # Splitting into train and test data.
    trainset, testset = train_test_split(train_df, test_size=0.2, random_state=32)

    return trainset, testset

Overwriting script/get_data.py


In [2]:
%%writefile script/dataloader.py

from pathlib import Path
from os import cpu_count
from typing import Tuple
from get_data import data
from segdataset import SaltSegmentationDataset
from augmentation import train_augs, test_augs

from torch.utils.data import DataLoader

import torch

def create_loaders(DATA_PATH:Path, batch_size:int) -> Tuple[DataLoader, DataLoader]:
    """Create the training and evaluation ``DataLoader`` objects.

    Args:
        DATA_PATH: Data root handed to ``data`` to build the train/test tables.
        batch_size: Number of samples per batch for both loaders.

    Returns:
        ``(trainloader, testloader)``. Only the training loader shuffles.
    """
    traindata, testdata = data(DATA_PATH)

    # Load dataset and apply augmentation techniques.
    trainset = SaltSegmentationDataset(traindata.images, traindata.masks, train_augs())
    testset = SaltSegmentationDataset(testdata.images, testdata.masks, test_augs())

    # os.cpu_count() returns None when the count cannot be determined, and
    # DataLoader needs an integer; 0 means "load in the main process".
    num_workers = cpu_count() or 0
    # Pinned host memory only makes sense when batches will be copied to a GPU.
    pin_memory = torch.cuda.is_available()

    # Wrap an iterable dataloader around the images and masks.
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                             pin_memory=pin_memory, num_workers=num_workers)
    testloader = DataLoader(testset, batch_size=batch_size,
                            pin_memory=pin_memory, num_workers=num_workers)

    return trainloader, testloader

Overwriting script/dataloader.py


In [3]:
%%writefile script/augmentation.py

import albumentations as A

def train_augs(prob:int=1) -> A.Compose:
    """Assemble the augmentation pipeline used for training samples.

    ``prob`` is forwarded as ``p`` to the padding step and to the mask-aware
    crop; the resize and superpixel steps keep the library's default ``p``.
    """
    side = 128
    steps = [
        A.Resize(side, side),
        A.PadIfNeeded(min_height=side, min_width=side, p=prob),
        A.CropNonEmptyMaskIfExists(width=side, height=side, p=prob),
        A.Superpixels(max_size=side),
        # No A.Normalize() step is applied here.
    ]
    return A.Compose(steps)

def test_augs() -> A.Compose:
    """Assemble the evaluation pipeline: a plain resize to 128x128."""
    side = 128
    # No A.Normalize() step is applied here.
    steps = [A.Resize(side, side)]
    return A.Compose(steps)

Overwriting script/augmentation.py


In [4]:
%%writefile script/segdataset.py

from torchvision import io
from torch.utils.data import Dataset
from typing import Tuple
from pathlib import Path
import torch
import cv2

class SaltSegmentationDataset(Dataset):
    """Dataset of ``(image, mask)`` pairs read from disk as grayscale.

    Each item is loaded, passed through ``augmentations`` (called with
    ``image=`` and ``mask=`` keyword arguments and expected to return a mapping
    with the same two keys), converted to ``float32`` and divided by 255.
    Arrays are handed to the augmentation in channels-last layout and the
    result is returned without any further permutation.
    """

    def __init__(self, imagePaths:Path, maskPaths:Path, augmentations) -> None:
        """Store the sample locations and the augmentation callable.

        Args:
            imagePaths: Iterable of image file paths (a list or a pandas Series).
            maskPaths: Iterable of mask file paths, aligned with ``imagePaths``.
            augmentations: Callable applied to every ``image``/``mask`` pair.

        Raises:
            ValueError: If the two path collections differ in length.
        """
        # Materialise as lists so __getitem__ always indexes by POSITION. A
        # pandas Series indexed by string ids would otherwise treat the integer
        # coming from the DataLoader as a label and only work through the
        # deprecated positional fallback.
        self.imagePaths = list(imagePaths)
        self.maskPaths = list(maskPaths)
        if len(self.imagePaths) != len(self.maskPaths):
            raise ValueError(
                f'Got {len(self.imagePaths)} image paths but {len(self.maskPaths)} mask paths'
            )
        self.augmentations = augmentations

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.imagePaths)

    def __getitem__(self, idx):
        """Load, augment and scale the pair stored at position ``idx``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the mask file.
        """
        # read_image yields channels-first data; move channels last for the
        # augmentation call.
        image = io.read_image(str(self.imagePaths[idx]), mode=io.ImageReadMode.GRAY).permute(1, 2, 0).numpy()

        mask_path = str(self.maskPaths[idx])
        mask_bgr = cv2.imread(mask_path, 1)
        if mask_bgr is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f'Could not read mask image: {mask_path}')
        # Collapse to one channel and add a trailing channel axis to match the image.
        mask = cv2.cvtColor(mask_bgr, cv2.COLOR_BGR2GRAY)[:, :, None]

        augment = self.augmentations(image=image, mask=mask)
        image = augment['image']
        mask = augment['mask']

        # Scale 8-bit intensities by 1/255.
        image = torch.Tensor(image).type(torch.float32)/255.0
        mask = torch.Tensor(mask).type(torch.float32)/255.0

        return image, mask

Overwriting script/segdataset.py


In [5]:
%%writefile script/architecture.py

from torch import nn
from segmentation_models_pytorch.losses import DiceLoss
import segmentation_models_pytorch as smp

class SaltSegmentationModel(nn.Module):
    """U-Net segmentation network that can also compute its own training loss.

    The network is an ``smp.Unet`` with a ``timm-efficientnet-b0`` encoder
    initialised from ImageNet weights, taking one input channel and emitting
    one output channel of raw logits (no final activation).
    """

    def __init__(self):
        super().__init__()

        self.arc = smp.Unet(
            encoder_name='timm-efficientnet-b0',
            encoder_weights='imagenet',
            in_channels=1,
            classes=1,
            activation=None
        )

        # Build the loss modules once instead of on every forward call.
        self.dice_loss = DiceLoss(mode='binary')
        self.bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, images, masks=None):
        """Run the network, optionally returning the loss as well.

        Args:
            images: Input batch passed straight to the U-Net.
            masks: Optional target masks. When given, the loss is computed.

        Returns:
            ``logits`` when ``masks`` is ``None``; otherwise the tuple
            ``(logits, loss)`` where ``loss`` is Dice loss + BCE-with-logits.
        """
        logits = self.arc(images)

        # Identity check: `masks != None` would go through the tensor's
        # comparison operator instead of simply testing for "no masks given".
        if masks is None:
            return logits

        loss = self.dice_loss(logits, masks) + self.bce_loss(logits, masks)
        return logits, loss

Overwriting script/architecture.py


In [6]:
%%writefile script/engine.py

from typing import Tuple, List
from numpy import Inf
from pathlib import Path
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch import device
from tqdm.auto import tqdm

from script.architecture import SaltSegmentationModel

import torch

def train_func(dataloader:DataLoader, model:SaltSegmentationModel, optimizer:Adam) -> float:
    """Run one optimisation pass over ``dataloader``.

    Args:
        dataloader: Yields ``(images, masks)`` batches in channels-last layout.
        model: Called as ``model(images, masks)`` and expected to return
            ``(logits, loss)``.
        optimizer: Optimiser stepped once per batch.

    Returns:
        The sum of the per-batch losses divided by the number of batches.
    """
    target = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.train()

    running_loss = 0.0
    for batch_images, batch_masks in tqdm(dataloader):
        # Move the channel axis from last to second position, then to the device.
        batch_images = batch_images.permute(0, 3, 1, 2).to(target)
        batch_masks = batch_masks.permute(0, 3, 1, 2).to(target)

        optimizer.zero_grad()
        _, batch_loss = model(batch_images, batch_masks)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / len(dataloader)


def test_func(dataloader:DataLoader, model:SaltSegmentationModel) -> float:
    """Evaluate ``model`` over ``dataloader`` without updating any weights.

    Args:
        dataloader: Yields ``(images, masks)`` batches in channels-last layout.
        model: Called as ``model(images, masks)`` and expected to return
            ``(logits, loss)``.

    Returns:
        The sum of the per-batch losses divided by the number of batches.
    """
    model.eval()
    target = 'cuda' if torch.cuda.is_available() else 'cpu'

    running_loss = 0.0
    with torch.inference_mode():
        for batch_images, batch_masks in tqdm(dataloader):
            # Move the channel axis from last to second position, then to the device.
            batch_images = batch_images.permute(0, 3, 1, 2).to(target)
            batch_masks = batch_masks.permute(0, 3, 1, 2).to(target)

            _, batch_loss = model(batch_images, batch_masks)
            running_loss += batch_loss.item()

    return running_loss / len(dataloader)


def train(trainloader:DataLoader, testloader:DataLoader, model:SaltSegmentationModel,
          optimizer:Adam, EPOCHS:int, DIR:Path) -> Tuple[List, List]:
    """Train for ``EPOCHS`` epochs, checkpointing the best model seen so far.

    After every epoch the model is evaluated on ``testloader``; whenever the
    evaluation loss improves, the weights are written to
    ``DIR/model/best_salt_model.pt``.

    Args:
        trainloader: Batches used for optimisation.
        testloader: Batches used for the per-epoch evaluation.
        model: Network to train.
        optimizer: Optimiser bound to ``model``'s parameters.
        EPOCHS: Number of passes over ``trainloader``.
        DIR: Base directory under which the ``model`` folder is created.

    Returns:
        ``(train_losses, test_losses)``: one mean loss per epoch for each loader.
    """
    # Plain float infinity keeps this function independent of the numpy `Inf`
    # alias, which no longer exists in NumPy 2.x.
    best_valid_loss = float('inf')

    # Create the checkpoint folder up front so the first save cannot fail
    # after a whole epoch of work has already been spent.
    checkpoint_path = DIR/'model/best_salt_model.pt'
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

    print("*"*60)
    print('                         START TRAINING               ')
    print("*"*60)
    train_losses, test_losses = list(), list()
    for i in tqdm(range(EPOCHS)):
        train_loss = train_func(trainloader, model, optimizer)
        test_loss = test_func(testloader, model)

        train_losses.append(train_loss)
        test_losses.append(test_loss)

        if test_loss < best_valid_loss:
            torch.save(model.state_dict(), checkpoint_path)
            print('SAVED MODEL')

            best_valid_loss = test_loss

        print(f'Epoch : {i+1} : Training loss : {train_loss} | Validation loss : {test_loss}')

    print("*"*60)
    print('                         TRAINING ENDS               ')
    print("*"*60)

    return train_losses, test_losses

Overwriting script/engine.py


In [7]:
%%writefile script/train.py

import os

from argparse import ArgumentParser
from pathlib import Path
from script.dataloader import create_loaders
from script.architecture import SaltSegmentationModel
from script.engine import train
from torch.optim import Adam
from matplotlib import pyplot as plt

import torch

# Command-line interface: three hyperparameters plus the data directory.
parser = ArgumentParser(description='Get some hyperparameters', add_help=True)
parser.add_argument(
    '-e', '--num_epochs',
    type=int, default=5, metavar='EPOCH',
    help='number of epochs',
)
parser.add_argument(
    '-bs', '--batch_size',
    type=int, default=32,
    help='number of batch or sample size',
)
parser.add_argument(
    '-lr', '--learning_rate',
    type=float, default=0.001, metavar='LR',
    help='learning rate',
)
parser.add_argument(
    '--data-dir',
    type=Path, default=Path('./'),
    help='directory of training data',
)
args = parser.parse_args()

# Hyperparameters taken from the command line.
NUM_EPOCHS, BATCH_SIZE, LR = args.num_epochs, args.batch_size, args.learning_rate

# Data directory (argparse exposes --data-dir as `data_dir`).
DIR = args.data_dir
print(f'Data file path : {DIR}')

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Data, model and optimiser are built at import time and read by plot_loss().
trainloader, testloader = create_loaders(DIR, BATCH_SIZE)
model = SaltSegmentationModel().to(device)
optimizer = Adam(model.parameters(), lr=LR)

def plot_loss():
    """Train the model, then save the loss curves to ``DIR/pictures/losses.png``.

    Uses the module-level loaders, model, optimiser and hyperparameters. The
    output folder is created if it does not exist yet.
    """
    # Make sure the figure can be written BEFORE spending time on training.
    output_path = DIR/'pictures/losses.png'
    output_path.parent.mkdir(parents=True, exist_ok=True)

    train_losses, test_losses = train(trainloader, testloader, model, optimizer=optimizer, EPOCHS=NUM_EPOCHS, DIR=DIR)

    # Plot the per-epoch training and evaluation loss.
    plt.style.use("ggplot")
    fig = plt.figure()
    plt.plot(train_losses, label="train_loss")
    plt.plot(test_losses, label="test_loss")
    plt.title("Training Loss on Dataset")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend(loc="upper right")

    fig.savefig(output_path)
    # Release the figure; nothing displays it after it has been saved.
    plt.close(fig)

if __name__ == '__main__':
    plot_loss()

Overwriting script/train.py


In [8]:
%%writefile script/predict.py

import torch
from torchvision import io
from argparse import ArgumentParser
from pathlib import Path
from script.architecture import SaltSegmentationModel
from torchvision.transforms import Resize
from matplotlib import pyplot as plt

# Command-line interface of the prediction script.
parser = ArgumentParser()

parser.add_argument('--image_path', type=str, help='Image to predict mask')

parser.add_argument('--model_path', 
                    default=Path('./model/best_salt_model.pt'), type=str,
                    help='Model Path')

# The help text used to be a copy of the --image_path one; this option is the
# base name of the figure that plot_mask() writes under ./pictures/.
parser.add_argument('--image_name', type=str,
                    help='Name (without extension) of the PNG saved in ./pictures/')

args = parser.parse_args()

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def load_model(filepath=args.model_path):
    """Instantiate the segmentation model and load trained weights into it.

    Args:
        filepath: Path of the ``state_dict`` checkpoint. Defaults to the
            ``--model_path`` command-line value.

    Returns:
        The model on ``device``, in evaluation mode, with the weights loaded.
    """
    model = SaltSegmentationModel().to(device)
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU machine also loads on a CPU-only one.
    model.load_state_dict(torch.load(filepath, map_location=device))
    model.eval()

    return model

def predict_mask(image_path=args.image_path, filepath=args.model_path):
    """Predict a binary mask for one image.

    Args:
        image_path: Image file to segment. Defaults to ``--image_path``.
        filepath: Checkpoint to load. Defaults to ``--model_path``.

    Returns:
        ``(image, pred_mask)``: the resized input with a leading batch axis, on
        ``device``, and a ``float32`` mask of 0.0/1.0 values obtained by
        thresholding the sigmoid of the logits at 0.5.
    """
    model = load_model(filepath)
    model.eval()

    # read_image returns an 8-bit, channels-first tensor; scale it by 1/255.
    image = io.read_image(str(image_path), mode=io.ImageReadMode.GRAY)
    image = image.type(torch.float32)/255.0

    # Resize the image to the 128x128 size used for prediction.
    transform = Resize(size=(128, 128))
    image = transform(image)

    # Predict on image
    with torch.inference_mode():
        # Add the batch axis. With a single gray channel this gives the same
        # shape as the previous unsqueeze(1).
        image = image.unsqueeze(0).to(device)

        logits = model(image)
        # Threshold first, then convert. The previous `> 0.5 * 1.0` multiplied
        # the threshold instead of the result and so produced a bool mask.
        pred_mask = (torch.sigmoid(logits) > 0.5).type(torch.float32)

    return image, pred_mask

def plot_mask():
    """Predict the mask for ``--image_path`` and save a side-by-side figure.

    The figure is written to ``./pictures/<image_name>.png``. When
    ``--image_name`` is not given, the stem of the input file name is used.
    The output folder is created if it does not exist yet.
    """
    image, pred_mask = predict_mask(image_path=args.image_path, filepath=args.model_path)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10,5))

    ax1.set_title('IMAGE')
    ax1.imshow(image.detach().cpu().squeeze(), cmap='seismic_r')

    ax2.set_title('PREDICTED')
    ax2.imshow(pred_mask.detach().cpu().squeeze(), cmap='jet')

    # Concatenating a missing --image_name used to raise TypeError after the
    # prediction had already been computed; fall back to the input's stem.
    output_name = args.image_name
    if output_name is None:
        output_name = Path(str(args.image_path)).stem

    output_dir = Path('./pictures')
    output_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_dir/f'{output_name}.png')
    # Release the figure; nothing displays it after it has been saved.
    plt.close(fig)

if __name__ == '__main__':
    plot_mask()
     

Overwriting script/predict.py


In [23]:
!python train.py --num_epochs 5 --batch_size 64 --learning_rate 0.001

Data file path : /content/drive/My Drive/Salt Dataset
************************************************************
                         START TRAINING               
************************************************************
  0% 0/5 [00:00<?, ?it/s]
  0% 0/50 [00:00<?, ?it/s][A
  2% 1/50 [00:08<06:46,  8.30s/it][A
  4% 2/50 [00:08<02:55,  3.66s/it][A
  6% 3/50 [00:09<01:41,  2.16s/it][A
  8% 4/50 [00:09<01:08,  1.48s/it][A
 10% 5/50 [00:09<00:49,  1.10s/it][A
 12% 6/50 [00:10<00:37,  1.16it/s][A
 14% 7/50 [00:11<00:34,  1.25it/s][A
 16% 8/50 [00:12<00:36,  1.15it/s][A
 18% 9/50 [00:12<00:33,  1.22it/s][A
 20% 10/50 [00:14<00:44,  1.11s/it][A
 22% 11/50 [00:15<00:38,  1.02it/s][A
 24% 12/50 [00:15<00:33,  1.14it/s][A
 26% 13/50 [00:16<00:33,  1.10it/s][A
 28% 14/50 [00:17<00:31,  1.13it/s][A
 30% 15/50 [00:18<00:30,  1.15it/s][A
 32% 16/50 [00:19<00:31,  1.10it/s][A
 34% 17/50 [00:20<00:30,  1.07it/s][A
 36% 18/50 [00:21<00:28,  1.11it/s][A
 38% 19/50 [00:22<00

In [56]:
!python predict.py --image_path './images/739a2484b4.png' --image_name 'Mask5'