In [8]:
import torch 
import pandas as pd
from torch.utils.data import DataLoader

from segmentation.config import CFG
from segmentation.models.unet import unet
from segmentation.scr.utils import losses, transforms
from segmentation.scr.utils.utils import set_seed
from segmentation.scr.tilling_dataset import Tilling_Dataset
from segmentation.scr.train_function import train_model

from colorama import Fore, Style
# Colorama shortcuts for coloured console output: `c_` switches to green text,
# `sr_` resets all styling. Neither is used in the cells shown here —
# presumably kept for print statements elsewhere; TODO remove if unused.
c_  = Fore.GREEN
sr_ = Style.RESET_ALL

# Globally silence pandas' SettingWithCopyWarning for everything running in
# this kernel. NOTE(review): this also hides genuine chained-assignment bugs;
# prefer fixing the offending writes with `.loc[...]` (presumably inside the
# project's dataset code) and then deleting this line.
pd.options.mode.chained_assignment = None

In [9]:
# Augmentation pipelines from the project's transforms module — the training
# variant first, then the validation variant.
train_transform, val_transform = (
    transforms.get_transform(transform_type=split) for split in ('train', 'val')
)

In [10]:
# Training split: kidney_1 tiles. With random subsampling, no empty tiles
# (empty_tile_pct=0) and a cap of 8000 samples, the dataset logs
# "Sample 0 empty and 8000 non-empty tiles".
train_dataset = Tilling_Dataset(
    name_data='kidney_1_tilling',
    path_to_df=CFG.path_df_kidney_1_till,
    transform=train_transform,
    # subsampling controls
    use_random_sub=True,
    sample_limit=8000,
    empty_tile_pct=0,
    random_seed=CFG.random_seed,
)


Dataset contains 20744 empty and 33952 non-empty tiles.
Sample 0 empty and 8000 non-empty tiles.


In [11]:
# Validation split: kidney_3 tiles, capped at 2000 samples of which 6% are
# empty (logged as 120 empty / 1880 non-empty).
# NOTE(review): the logged totals for kidney_3 are 14595 empty vs 6447
# non-empty (~69% empty), so this subset is far less empty than the full
# volume — validation scores may not reflect full-kidney performance.
val_dataset = Tilling_Dataset(
    name_data='kidney_3_tilling',
    path_to_df=CFG.path_df_kidney_3_till,
    transform=val_transform,
    # subsampling controls
    use_random_sub=True,
    sample_limit=2000,
    empty_tile_pct=6,
    random_seed=CFG.random_seed,
)

Dataset contains 14595 empty and 6447 non-empty tiles.
Sample 120 empty and 1880 non-empty tiles.


In [12]:
# Re-seed the project's RNGs right before building the loaders.
set_seed(CFG.random_seed)

# Explicitly seeded generator for the training loader. DataLoader uses it for
# the shuffle order and for the per-worker base seed. Without it, both are
# drawn from the global torch RNG when iteration starts. That happens after
# model weight initialisation in the next cell, so the order would depend on
# how much global randomness was consumed in between.
loader_generator = torch.Generator().manual_seed(CFG.random_seed)

# Page-locked (pinned) host memory only helps host->CUDA copies; requesting it
# without a CUDA device just triggers a warning.
use_pin_memory = torch.device(CFG.device).type == 'cuda'
NUM_WORKERS = 2

train_loader = DataLoader(
    train_dataset,
    batch_size=CFG.train_batch_size,
    num_workers=NUM_WORKERS,
    shuffle=True,
    pin_memory=use_pin_memory,
    generator=loader_generator,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=CFG.valid_batch_size,
    num_workers=NUM_WORKERS,
    shuffle=False,
    pin_memory=use_pin_memory,
)

In [13]:
# Model, loss, optimiser and LR schedule for the training run in the next cell.
device = CFG.device

# Attention U-Net with 3 input channels and 1 output channel.
model = unet.AttU_Net(n_channels=3, n_classes=1, bilinear=True).to(device)

num_epoch = 60
LEARNING_RATE = 3e-4

loss_fn = losses.BCE_DICE(mode="SUM")
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# ReduceLROnPlateau behaviour:
# - Multiplies the LR by 0.6 once the monitored value has not improved by more
#   than 1e-3 (relative, the default threshold_mode) for 4 consecutive
#   scheduler.step() calls.
# - With mode='max', train_model must pass a higher-is-better metric (e.g. a
#   validation Dice score) to scheduler.step(metric). TODO confirm in
#   train_function.
# `verbose=True` was dropped: it is deprecated since PyTorch 2.2 (emits a
# warning), and the current LR is already shown in the training progress bar.
# The variable keeps its original spelling because the training cell passes
# `scheduler=sheduler`.
sheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer, mode='max', patience=4, factor=0.6, threshold=1e-3)

In [14]:
# Run the full training loop with everything built in the previous cells;
# CFG.path_to_save_state_model is passed through as `path_to_save`.
train_model(
    model=model,
    loss_func=loss_fn,
    optimizer=optimizer,
    scheduler=sheduler,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=num_epoch,
    device=CFG.device,
    path_to_save=CFG.path_to_save_state_model,
)


Epoch 1/60
----------


Train :   1%|          | 24/2000 [00:09<12:53,  2.55it/s, epoch=24, gpu_mem=3.73 GB, lr=0.00030, train_loss=0.0917]


KeyboardInterrupt: 