In [23]:
%load_ext autoreload
%autoreload 2

import RS_utils
import RS_dataset
import RS_models
#---
import datetime
import logging
import numpy as np 
from glob import glob
import os
import torch
import matplotlib.pyplot as plt 
from torch.utils.data import DataLoader
#---
import torch.nn as nn 
#---
from lightning.fabric import Fabric
import lightning as L
import segmentation_models_pytorch as smp


In [26]:

#-- data
# NOTE(review): absolute, user-specific paths; consider deriving them from one configurable root.
img_path = "/mnt/hdd/eric/.tmp_ipy/15.Lab_Detection/01.Models/04.SAM_fine/0.data/01.512_imgs"
mask_path = "/mnt/hdd/eric/.tmp_ipy/15.Lab_Detection/01.Models/04.SAM_fine/0.data/02.512_masks"

# Images and masks are paired purely by sorted position, so both folders must
# hold the same number of PNGs; otherwise pairs silently become misaligned.
img_path_ship = np.array(sorted(glob(os.path.join(img_path, "*.png"))))
mask_path_ship = np.array(sorted(glob(os.path.join(mask_path, "*.png"))))
assert len(img_path_ship) == len(mask_path_ship), (
    f"image/mask count mismatch: {len(img_path_ship)} images vs {len(mask_path_ship)} masks"
)

# Saved selector array applied identically to both path lists
# (presumably the ship-containing tiles, per the file name).
aa = np.load("/mnt/hdd/eric/.tmp_ipy/15.Lab_Detection/05.Training/Segmentation/03.data_list/512_ships.npy")

selected_paths_img = img_path_ship[aa]
selected_paths_mask = mask_path_ship[aa]


#-- args
TASK = "SHIP"
MODEL_NAME = "UNET_PP_EDGE"
EXEC_VER = 40
BATCH_SIZE = 4
DEVICE = "cuda:0"
DEVICES = [0, 1, 2, 3]
RESUME = False
SAVE_EPOCH = 20


#-- category
ISAID_CLASSES_SHIP = (
    'background', 'ship', 'harbor'
    )
# class index -> RGB triple; passed to the dataset as `palette`
# (presumably the colours used in the mask PNGs — confirm in RS_dataset).
ISAID_PALETTE_SHIP = {
    0: (0, 0, 0),
    1: (0, 0, 63),
    2: (0, 100, 155)}

#--- logger
# Console output comes from the root handler installed by basicConfig; this
# module's logger additionally writes to a per-run file under ./01.log/.
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
LOG_DIR = './01.log'
os.makedirs(LOG_DIR, exist_ok=True)  # FileHandler raises if the directory is missing
log_filename = datetime.datetime.now().strftime(f'{LOG_DIR}/ver_{EXEC_VER}_%Y-%m-%d_%H-%M-%S.log')
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
logger = logging.getLogger(__name__)
# Re-running this cell would otherwise stack another FileHandler on the same
# logger and write every record once per run; drop handlers from earlier runs.
for old_handler in [h for h in logger.handlers if isinstance(h, logging.FileHandler)]:
    logger.removeHandler(old_handler)
    old_handler.close()
handler = logging.FileHandler(log_filename)
handler.setFormatter(logging.Formatter(LOG_FORMAT))  # same format as the console
logger.addHandler(handler)

#-- dataset
train_dataset = RS_dataset.Seg_RS_dataset_SegEdge(img_dir=selected_paths_img, mask_dir=selected_paths_mask, image_resize=None, phase="train", palette=ISAID_PALETTE_SHIP)
dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=train_dataset.collate_fn)

#--- model
# Segmentation net (trainable) and edge net (used as a frozen feature extractor).
# map_location="cpu": the models are built on CPU and load_state_dict copies the
# tensors in anyway, so this avoids allocating on / requiring the saving GPU.
# NOTE: torch.load unpickles the file — only load trusted checkpoints.
model_1 = smp.UnetPlusPlus(encoder_name="resnet152", classes=3)
tgt_path = "/mnt/hdd/eric/.tmp_ipy/15.Lab_Detection/01.Models/07.Unet_PlusPlus_EdgeNet/02.ckpts/ver_35_unet_epoch_101_iteration_470256.pt"
checkpoint = torch.load(tgt_path, map_location="cpu")
model_1.load_state_dict(checkpoint)

model_2 = RS_models.Edge_Net()
tgt_path = "/mnt/hdd/eric/.tmp_ipy/15.Lab_Detection/01.Models/07.Unet_PlusPlus_EdgeNet/02.ckpts/ver_31_edgenet_epoch_101.pt"
checkpoint = torch.load(tgt_path, map_location="cpu")
model_2.load_state_dict(checkpoint)

<All keys matched successfully>

In [2]:
# Display the edge network's layer structure.
model_2

Edge_Net(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu): ReLU()
  (softmax): Softmax(dim=1)
)

In [93]:
class CombinedModel(nn.Module):
    """Segmentation net trained with a perceptual loss from a frozen edge net.

    ``model1`` (segmentation network) stays trainable. ``model2`` (edge
    network) is frozen and used only as a fixed feature extractor: both the
    prediction and the ground-truth mask are passed through it, and the mean
    L1 distances between its three outputs are summed.

    Args:
        model1: trainable segmentation module, called as ``model1(img)``.
        model2: feature module whose call returns exactly three tensors;
            its parameters are set to ``requires_grad = False``.
    """

    def __init__(self, model1, model2):
        super().__init__()
        self.model1 = model1
        self.model2 = model2

        # seg_net: train
        for param in self.model1.parameters():
            param.requires_grad = True

        # edge_net: freeze. Gradients still flow *through* model2 back into
        # model1's outputs; only model2's own weights receive no gradient.
        for param in self.model2.parameters():
            param.requires_grad = False

    def save_pretrained(self, path):
        """Save this module's ``state_dict`` to ``path`` via ``torch.save``.

        Keys are prefixed with ``model1.`` / ``model2.``. For a checkpoint that
        loads directly into the bare segmentation net, save
        ``self.model1.state_dict()`` instead.
        """
        # Previously this called self.save_pretrained(path) -> infinite recursion.
        torch.save(self.state_dict(), path)

    def forward(self, img, mask):
        """Run the segmentation net and compute the edge perceptual loss.

        Args:
            img: input batch for ``model1``.
            mask: ground-truth target passed through ``model2``; should match
                the shape of ``model1(img)`` so the feature maps are comparable.

        Returns:
            ``(outputs, loss_percept)``: ``model1``'s output and the sum of the
            three mean L1 losses between ``model2`` features of prediction
            and target.
        """
        outputs = self.model1(img)

        # Perceptual loss from edge_net.
        # NOTE(review): `outputs` goes into model2 unnormalised (no softmax)
        # while `mask` is the target — confirm both share a value range.
        layer_1_out, layer_2_out, layer_3_out = self.model2(outputs)

        # The target branch only provides constants for the loss; no graph needed.
        with torch.no_grad():
            layer_1_gt, layer_2_gt, layer_3_gt = self.model2(mask)

        l1 = torch.nn.functional.l1_loss
        loss_percept = (
            l1(layer_1_out, layer_1_gt, reduction='mean')
            + l1(layer_2_out, layer_2_gt, reduction='mean')
            + l1(layer_3_out, layer_3_gt, reduction='mean')
        )

        return outputs, loss_percept

In [94]:
# Wrap the trainable seg net and the frozen edge net into one module.
combine_net = CombinedModel(model1=model_1, model2=model_2)

In [95]:
# Draw one batch for a smoke test. A fresh iterator is created each run and the
# loader shuffles, so re-running this cell yields a different batch.
batch = next(iter(dataloader))

In [96]:
# First three entries of the collated batch: image, mask, and (presumably)
# the edge target produced by Seg_RS_dataset_SegEdge.
img, mask, edge = (batch[idx] for idx in range(3))

In [97]:
# Quick inference check. pred_1 is only inspected, so skip building an
# autograd graph through the ResNet-152 UNet++.
# NOTE: eval() mutates the shared model_1 object, which is also
# combine_net.model1 — later combine_net calls run in eval mode until .train().
model_1 = model_1.eval()
with torch.no_grad():
    pred_1 = model_1(img)

In [98]:
# Prediction shape; with BATCH_SIZE=4 and classes=3 this shows (4, 3, H, W).
pred_1.shape

torch.Size([4, 3, 256, 256])

In [99]:
# Target mask shape; it should match pred_1 so CombinedModel's edge features are comparable.
mask.shape

torch.Size([4, 3, 256, 256])

In [100]:
# Edge-net features of the ground-truth mask (CombinedModel.forward unpacks this into three maps).
out_ = model_2(mask)

In [101]:
# Inspect the three edge-feature maps computed in the previous cell
# (this cell used to re-run `model_2(mask)` verbatim, duplicating the work).
[feat.shape for feat in out_]

In [112]:
# Seg-net outputs plus the edge perceptual loss for the smoke-test batch.
outputs_, loss_percept = combine_net(img=img, mask=mask)

In [117]:
# Output shape of combine_net; should equal pred_1.shape.
outputs_.shape

torch.Size([4, 3, 256, 256])

In [115]:
# Scalar value of the summed edge perceptual loss.
float(loss_percept)

45.868900299072266

In [116]:
# `.item()` only works on single-element tensors; outputs_[1] is the whole
# (3, H, W) output map of the 2nd sample, so it raised RuntimeError.
# NOTE(review): original intent unclear — this shows which class ids the
# model predicts (argmax over the class dim) anywhere in that sample.
outputs_[1].argmax(dim=0).unique()