In [None]:
import random

import numpy as np

import model_torch
import modules_torch
import _utils_torch
from Preprocessing import *
from _utils_torch import *
import loss

In [None]:
# Restrict this process to physical GPU 0. The variable only takes effect if it is set
# before the CUDA runtime is initialised. NOTE(review): torch is already in scope here
# (presumably via the star imports above); this is fine as long as nothing imported so
# far has touched CUDA — confirm.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

cuda_ok = torch.cuda.is_available()
print(cuda_ok)

In [None]:
data_dir = '../Data'
# os.listdir returns entries in an arbitrary, filesystem-dependent order. The next cell
# splits train/test by list position, so sort here to make that split reproducible.
data_list = sorted(os.listdir(data_dir))

# Bucket file names by their 'S1' / 'S2' tag. A name containing both tags goes to S1
# (first match wins); names containing neither are ignored.
S1_list = [item for item in data_list if 'S1' in item]
S2_list = [item for item in data_list if 'S1' not in item and 'S2' in item]

In [None]:
# The first `n_train_per_group` files of each group are used for training; the rest are
# held out for testing.
n_train_per_group = 5
train_data_list = S1_list[:n_train_per_group] + S2_list[:n_train_per_group]
test_data_list = S1_list[n_train_per_group:] + S2_list[n_train_per_group:]

# Fail here with a clear message rather than later inside preprocessing / .min() on an
# empty test set.
assert test_data_list, (
    f"No test files left: each group needs more than {n_train_per_group} files."
)

In [None]:
# Run the project preprocessing on both file lists. The result unpacks into training
# inputs/targets followed by test inputs/targets.
(X_train, Y_train,
 X_test, Y_test) = preprocessing(train_data_list, test_data_list)

In [None]:
# Sanity-check every split: shape, value range and dtype
# (train inputs, train targets, test inputs, test targets — in that order).
for split_arr in (X_train, Y_train, X_test, Y_test):
    print(split_arr.shape, split_arr.min(), split_arr.max(), split_arr.dtype)

In [None]:
h_params = EasyDict()

# Hardware / reproducibility
h_params.gpu_num = 0  # index used to build torch.device("cuda:<gpu_num>")
h_params.seed = 42

# Optimisation schedule
h_params.total_epoch = 100
h_params.batch_size = 32
h_params.lr = 1e-3
h_params.lr_schedule_patience = 3  # passed as `patience` to ReduceLROnPlateau
h_params.earlystop_patience = 5    # passed as `patience` to EarlyStopping

# Output locations
h_params.model_name = "./output/model.pth"
h_params.model_save_base = os.path.join(os.getcwd(), "./output/train_log")

# One sub-directory per run, stamped YYYYMMDD_HHMM in local time.
now = time.localtime(time.time())
h_params.trial_ = time.strftime("%Y%m%d_%H%M", now)
h_params.trial_path = os.path.join(h_params.model_save_base, h_params.trial_)

# model_name is itself a relative path ("./output/model.pth"). Joining it verbatim would
# point into trial_path/./output/, a directory that is never created. Keep only the file
# name so the checkpoint path lands directly inside the run directory.
h_params.model_save_path = os.path.join(
    h_params.trial_path, os.path.basename(h_params.model_name)
)

In [None]:
os.makedirs(h_params.trial_path, exist_ok=True)
device = torch.device(f"cuda:{h_params.gpu_num}" if torch.cuda.is_available() else 'cpu')

# h_params.seed was defined but never applied. Seed the Python, NumPy and torch RNGs
# before the cells that draw from them: augmentation, DataLoader shuffling, weight init.
random.seed(h_params.seed)
np.random.seed(h_params.seed)
torch.manual_seed(h_params.seed)  # seeds the RNG on all devices, CUDA included

In [None]:
def _resize_and_tensorize():
    """Final steps shared by both pipelines: resize to 512x512, then convert with ToTensorV2."""
    return [A.Resize(width=512, height=512), ToTensorV2()]


# Training pipeline: light geometric + photometric jitter, then the shared final steps.
# NOTE(review): each augmentation is applied with p=0.01 and tiny limits (±0.01 shift/scale,
# ±1 rotation, 0.001 brightness/contrast), so almost every sample passes through
# unaugmented — confirm this is intended.
train_transform = A.Compose(
    [
        A.ShiftScaleRotate(
            shift_limit=(-0.01, 0.01),
            scale_limit=(-0.01, 0.01),
            rotate_limit=(-1, 1),
            p=0.01,
        ),
        A.RandomBrightnessContrast(
            brightness_limit=0.001,
            contrast_limit=0.001,
            p=0.01,
        ),
        *_resize_and_tensorize(),
    ],
    p=1,
)

# Evaluation pipeline: no random steps, only the shared final steps.
val_transform = A.Compose(_resize_and_tensorize(), p=1)

In [None]:
# Wrap the preprocessed training arrays. train_transform is presumably applied per
# sample in custom_dataset.__getitem__ (TODO confirm).
train_dataset = custom_dataset(
    transform=train_transform,
    img=X_train,
    mask=Y_train,
)

# Mini-batches of h_params.batch_size, reshuffled at every epoch.
train_dataloader = DataLoader(
    shuffle=True,
    batch_size=h_params.batch_size,
    dataset=train_dataset,
)

In [None]:
# Visually check a window of (augmented) training samples that contain foreground.
preview_start, preview_stop = 150, 200

# Clamp the window so a dataset with fewer than `preview_stop` samples does not raise
# IndexError.
for idx in range(preview_start, min(preview_stop, len(train_dataset))):
    sample = train_dataset[idx]
    sam_img = sample['image']
    sam_mask = sample['mask']

    # Skip masks with no pixel >= 1 (presumably all-background) — nothing to overlay.
    if sam_mask.max() < 1:
        continue

    print(sam_img.shape, sam_img.min(), sam_img.max(), sam_img.dtype)
    print(sam_mask.shape, sam_mask.min(), sam_mask.max(), sam_mask.dtype)

    # Panels: input (channel 0) | mask | mask overlaid on the input.
    fig, (ax_img, ax_mask, ax_overlay) = plt.subplots(1, 3, figsize=(7, 7))
    ax_img.imshow(sam_img[0, ...], cmap='gray')
    ax_mask.imshow(sam_mask, cmap='gray')
    ax_overlay.imshow(sam_img[0, ...], cmap='gray')
    ax_overlay.imshow(sam_mask, cmap='jet', alpha=.5)
    for ax in (ax_img, ax_mask, ax_overlay):
        ax.axis('off')
    plt.show()

In [None]:
# U-Net with a ResNet-34 encoder (depth 5, 'imagenet' encoder weights), one input channel
# and one output channel. With activation=None the outputs are presumably raw logits;
# the training cell applies .sigmoid() itself before plotting.
unet_config = dict(
    encoder_name='resnet34',
    encoder_depth=5,
    encoder_weights='imagenet',
    in_channels=1,
    classes=1,
    activation=None,
)
model = model_torch.Unet(**unet_config).to(device)

In [None]:
optim = torch.optim.Adam(params=model.parameters(), lr=h_params.lr)
# Cuts the LR when the value passed to scheduler.step() stops decreasing for
# `lr_schedule_patience` epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optim, patience=h_params.lr_schedule_patience
)

# If any cell rebinds the global name `loss` (e.g. to a loss tensor), `loss.DiceLoss`
# fails when this cell is re-run. Bind the project loss module under a dedicated alias.
import loss as seg_losses

criterion_dice = seg_losses.DiceLoss(mode='binary', smooth=1e-07)
criterion_focal = seg_losses.FocalLoss(mode='binary')

# NOTE(review): checkpoints are written to h_params.model_name ("./output/model.pth"),
# which every run shares, not to the per-run h_params.model_save_path — confirm which is
# intended.
monitor = EarlyStopping(patience=h_params.earlystop_patience,
                        verbose=True,
                        delta=0,
                        path=h_params.model_name,
                        trace_func=print)

In [None]:
def train_show(image, target, pred):
    """Plot the first sample of a batch: input, input + target overlay, input + prediction overlay.

    Args:
        image: batch tensor indexed as ``image[0, 0]`` (presumably (B, C, H, W)); channel 0 is shown.
        target: batch tensor indexed as ``target[0]`` (presumably (B, H, W)).
        pred: batch tensor indexed as ``pred[0, 0]``; the training loop passes sigmoid
            probabilities.
    """
    def _first_to_numpy(t):
        # Detach (no autograd graph) and move only the selected slice to the CPU,
        # instead of copying the whole batch off the device.
        return t.detach().cpu().numpy()

    img = _first_to_numpy(image[0, 0])
    gt = _first_to_numpy(target[0])
    pr = _first_to_numpy(pred[0, 0])

    fig, axes = plt.subplots(1, 3, figsize=(8, 8))
    axes[0].imshow(img, cmap='gray')
    axes[1].imshow(img, cmap='gray')
    axes[1].imshow(gt, cmap='jet', alpha=.5)
    axes[2].imshow(img, cmap='gray')
    axes[2].imshow(pr, cmap='jet', alpha=.5)
    for ax in axes:
        ax.axis('off')
    plt.show()

In [None]:
# Per-epoch history: mean train loss, mean train Dice loss, and the LR used for the epoch.
metric_logger = {k: [] for k in ['train_loss', 'train_dice', 'lr']}

total_train_num = len(train_dataloader.sampler)

for epoch in range(h_params.total_epoch):
    # Adam was built from a single parameter iterable, so there is exactly one param group.
    lr_status = optim.param_groups[0]['lr']
    metric_logger['lr'].append(lr_status)
    epoch_loss = {k: 0.0 for k in metric_logger if k != 'lr'}
    model.train()

    for batch_idx, data in enumerate(tqdm(train_dataloader, total=len(train_dataloader),
                                          position=0, desc='Train', colour='blue')):
        batch_num = len(data['image'])

        image = data['image'].to(device, dtype=torch.float)
        target = data['mask'].to(device, dtype=torch.long)

        pred = model(image)

        focal_loss = criterion_focal(pred, target)
        dice_loss = criterion_dice(pred, target)
        # Not named `loss`: that would shadow the imported `loss` module and break
        # re-running the cell that builds the criteria.
        total_loss = focal_loss + dice_loss

        optim.zero_grad()
        total_loss.backward()
        optim.step()

        if batch_idx == 0:
            train_show(image, target, pred.detach().sigmoid())

        # Accumulate sample-weighted sums; they are averaged over all samples after the epoch.
        epoch_loss['train_dice'] += dice_loss.item() * batch_num
        epoch_loss['train_loss'] += total_loss.item() * batch_num

    # This cell only tracks training metrics, so every entry is averaged over the
    # training sample count.
    epoch_loss = {k: v / total_train_num for k, v in epoch_loss.items()}
    for k, v in epoch_loss.items():
        metric_logger[k].append(v)

    # Log before the early-stopping check so the final epoch is reported too.
    print(f"Epoch {epoch+1:03d}/{h_params.total_epoch:03d}\tLR: {lr_status:.0e}\n"
          f"Train_loss: {epoch_loss['train_loss']:.7f}\t"
          f"Train_dice_loss-> {epoch_loss['train_dice']:.7f}")

    # The monitor tracks the *training* loss (there is no validation loop here).
    monitor(epoch_loss['train_loss'], model)
    if monitor.early_stop:
        print(f"Train early stopped, Minimum train loss: {monitor.val_loss_min}")
        break

    scheduler.step(epoch_loss['train_loss'])