In [None]:
import RS_utils
import RS_dataset
import RS_models
#---
import datetime
import logging
import numpy as np 
from glob import glob
import os
import torch
import matplotlib.pyplot as plt 
from torch.utils.data import DataLoader
#---
import torch.nn as nn 
#---
from lightning.fabric import Fabric
import lightning as L
import segmentation_models_pytorch as smp

#-- data
# Source directories for the 512x512 image tiles and their segmentation masks.
img_path = "/mnt/hdd/eric/.tmp_ipy/15.Lab_Detection/01.Models/04.SAM_fine/0.data/01.512_imgs"
mask_path = "/mnt/hdd/eric/.tmp_ipy/15.Lab_Detection/01.Models/04.SAM_fine/0.data/02.512_masks"

# Images and masks are paired purely by sorted position, so both listings must
# be non-empty and the same length for that pairing to be meaningful.
img_path_ship  = np.array(sorted(glob(os.path.join(img_path, "*.png"))) )
mask_path_ship = np.array(sorted(glob(os.path.join(mask_path, "*.png"))) )

if len(img_path_ship) == 0 or len(mask_path_ship) == 0:
    # glob() silently returns [] for a missing/mistyped directory; fail here
    # instead of with an IndexError on the selection below.
    raise FileNotFoundError(
        f"no *.png files found (images: {img_path!r}, masks: {mask_path!r})"
    )
if len(img_path_ship) != len(mask_path_ship):
    raise ValueError(
        f"image/mask count mismatch: {len(img_path_ship)} images in {img_path!r} "
        f"vs {len(mask_path_ship)} masks in {mask_path!r}"
    )

# Saved selection used to index both arrays identically.
# NOTE(review): presumably integer indices or a boolean mask of "ship" tiles — confirm.
aa = np.load("/mnt/hdd/eric/.tmp_ipy/15.Lab_Detection/05.Training/Segmentation/03.data_list/512_ships.npy")

selected_paths_img = img_path_ship[aa]
selected_paths_mask  = mask_path_ship[aa]


#-- args: run configuration
TASK, MODEL_NAME = "SHIP", "UNET_PP"
EXEC_VER = 35               # experiment version; also embedded in the log file name
BATCH_SIZE = 4              # DataLoader batch size
DEVICE = "cuda:0"
DEVICES = list(range(4))    # GPU indices 0, 1, 2, 3
RESUME = False              # when True, load a checkpoint from 02.ckpts into the model
SAVE_EPOCH = 20             # NOTE(review): not read in this cell; presumably a save interval

#-- category: class names and their display colours, indexed by class id
ISAID_CLASSES_SHIP = ("background", "ship", "harbor")
ISAID_PALETTE_SHIP = dict(enumerate([(0, 0, 0), (0, 0, 63), (0, 100, 155)]))

#--- logger
# Console output comes from the root handler installed by basicConfig; each run
# also writes to its own timestamped file under ./01.log/.
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
log_dir = './01.log'
# FileHandler does not create parent directories, so make sure it exists.
os.makedirs(log_dir, exist_ok=True)
log_filename = datetime.datetime.now().strftime(f'{log_dir}/ver_{EXEC_VER}_%Y-%m-%d_%H-%M-%S.log')
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
logger = logging.getLogger(__name__)
# Re-executing this notebook cell would otherwise stack another FileHandler on
# the same logger and duplicate every line; drop (and close) previous ones.
for old_handler in [h for h in logger.handlers if isinstance(h, logging.FileHandler)]:
    logger.removeHandler(old_handler)
    old_handler.close()
handler = logging.FileHandler(log_filename)
# Without a formatter the file would receive bare messages only.
handler.setFormatter(logging.Formatter(LOG_FORMAT))
logger.addHandler(handler)

#-- dataset
# The class count is derived from the label list so the model head and the
# metric cannot drift from the category definitions.
NUM_CLASSES = len(ISAID_CLASSES_SHIP)
if len(ISAID_PALETTE_SHIP) != NUM_CLASSES:
    raise ValueError(
        f"palette has {len(ISAID_PALETTE_SHIP)} entries but there are {NUM_CLASSES} classes"
    )

train_dataset = RS_dataset.Seg_RS_dataset_ship(img_dir=selected_paths_img, mask_dir=selected_paths_mask, image_resize = None, phase="train",palette=ISAID_PALETTE_SHIP )
dataloader  = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=train_dataset.collate_fn)

#--- model
model = smp.UnetPlusPlus(encoder_name="resnet152", classes=NUM_CLASSES)
criterion = RS_dataset.UNet_metric(num_classes=NUM_CLASSES)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

#-- resume
if RESUME:
    tgt_path = "/mnt/hdd/eric/.tmp_ipy/15.Lab_Detection/05.Training/Segmentation/02.ckpts"
    # Consider only regular, non-hidden files in the checkpoint directory.
    ckpt_files = [
        os.path.join(tgt_path, name)
        for name in os.listdir(tgt_path)
        if not name.startswith(".") and os.path.isfile(os.path.join(tgt_path, name))
    ]
    if not ckpt_files:
        raise FileNotFoundError(f"RESUME=True but no checkpoint files found in {tgt_path}")
    # Pick the most recently written file: lexicographic order would rank
    # e.g. "epoch_100" before "epoch_20" and silently resume from an older one.
    ckpt_path = max(ckpt_files, key=os.path.getmtime)
    print("resume chekcpoint : ",ckpt_path)

    # Load onto CPU so resuming does not depend on the GPU the checkpoint was
    # saved from; device placement presumably happens in the fabric setup step.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(checkpoint)

#--- fabric setup (TODO: not done in this cell — `Fabric` is imported above but not yet used)



In [None]:



class CombinedModel(nn.Module):
    """Segmentation model plus a frozen feature network used for a perceptual loss.

    ``model1`` is called as ``model1(**batch)``; its output must expose ``.preds``
    and ``.loss``. ``model2`` is frozen at construction and must return three
    feature tensors for a given input; the L1 distances between the features of
    the prediction and of the ground truth are summed into the perceptual loss.
    """

    def __init__(self, model1, model2):
        super().__init__()
        self.model1 = model1
        self.model2 = model2

        # Freeze the edge network: it acts only as a fixed feature extractor.
        # NOTE(review): requires_grad=False does not stop BatchNorm running-stat
        # updates in train mode; call self.model2.eval() if that matters.
        for param in self.model2.parameters():
            param.requires_grad = False

    def save_pretrained(self, path):
        """Save this module's ``state_dict`` (both sub-models) to ``path`` via ``torch.save``.

        Reload with ``module.load_state_dict(torch.load(path))``.
        NOTE(review): if only ``model1`` should be exported (Hugging Face style),
        delegate to ``self.model1.save_pretrained(path)`` instead.
        """
        torch.save(self.state_dict(), path)

    def forward(self, batch, labels=None):
        """Run ``model1`` on ``batch`` and compute the segmentation and perceptual losses.

        Args:
            batch: mapping unpacked into ``model1`` as keyword arguments.
            labels: ground truth fed to ``model2``. When omitted it is read from
                ``batch["labels"]``.

        Returns:
            ``(loss_seg, loss_percept)``: ``loss_seg`` is ``outputs.loss`` from
            ``model1``; ``loss_percept`` is the sum of three L1 feature losses.

        Raises:
            KeyError: if ``labels`` is not given and ``batch`` has no ``"labels"``.
        """
        if labels is None:
            # assumes ground truth travels in the batch under "labels" — confirm
            labels = batch["labels"]

        outputs = self.model1(**batch)

        # Keep dim-2 indices from 448 onward, then resize to 512x512 (nearest).
        # NOTE(review): the meaning of the 448 offset is not visible here — confirm.
        # Cast before resizing: interpolate does not accept every integer dtype,
        # and nearest resizing only copies values, so float inputs are unaffected.
        pred = outputs.preds[:, :, 448:].float()
        pred = torch.nn.functional.interpolate(pred, size=(512, 512), mode="nearest")

        # Perceptual loss from the frozen edge net; gradients reach model1 only
        # through the prediction branch.
        layer_1_out, layer_2_out, layer_3_out = self.model2(pred)
        # The ground-truth branch needs no autograd graph.
        # assumes labels has the same layout as pred (e.g. (B, C, 512, 512)) — TODO confirm
        with torch.no_grad():
            layer_1_gt, layer_2_gt, layer_3_gt = self.model2(labels.float())

        loss_1 = torch.nn.functional.l1_loss(layer_1_out, layer_1_gt)
        loss_2 = torch.nn.functional.l1_loss(layer_2_out, layer_2_gt)
        loss_3 = torch.nn.functional.l1_loss(layer_3_out, layer_3_gt)

        #--- loss
        loss_seg = outputs.loss
        loss_percept = loss_1 + loss_2 + loss_3

        return loss_seg, loss_percept