In [38]:
import os
from os.path import join
import random
from glob import glob
from PIL import Image
from datetime import datetime
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import tqdm
import torch.nn as nn
import torch.nn.functional as F
import monai
from segment_anything import sam_model_registry
from segment_anything.predictor import SamPredictor
from segment_anything.modeling import ImageEncoderViT, MaskDecoder, PromptEncoder
from segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator
import matplotlib.pyplot as plt

In [39]:
# Pick the GPU whenever one is available. Set FORCE_CPU=True to debug on the
# CPU explicitly instead of silently overriding the detected device.
FORCE_CPU=False
device='cuda' if torch.cuda.is_available() and not FORCE_CPU else 'cpu'

# Checkpoint used to build the 'vit_h' entry of the SAM model registry.
model_path='./model/sam_vit_h_4b8939.pth'
sam_model=sam_model_registry['vit_h'](checkpoint=model_path).to(device)
# Predictor wrapping the same model instance.
predictor=SamPredictor(sam_model)

# Colour triple assigned to each class index (list position == class index).
num2color=dict(enumerate([
    [0,0,0],           # 0
    [0, 153, 255],     # 1
    [102, 255, 153],   # 2
    [0, 204, 153],     # 3
    [255, 255, 102],   # 4
    [255, 255, 204],   # 5
    [255, 153, 0],     # 6
    [255, 102, 255],   # 7
    [102, 0, 51],      # 8
    [255, 204, 255],   # 9
    [255, 0, 102],     # 10
]))

# Human-readable name of each class index (list position == class index).
num2label=dict(enumerate([
    'background',
    'skin',
    'left eyebrow',
    'right eyebrow',
    'left eye',
    'right eye',
    'nose',
    'upper lip',
    'inner mouth',
    'lower lip',
    'hair',
]))

In [40]:
class IG02DataSet(Dataset):
    """Index of the ``*.image.png`` files found directly under ``data_path``.

    Only file discovery is implemented. Sample loading is unfinished, so
    ``__getitem__`` raises ``NotImplementedError`` rather than silently
    handing ``None`` to a ``DataLoader``.
    """

    def __init__(self,data_path,single=True):
        """Collect and sort the image paths.

        Args:
            data_path: Directory that is globbed for ``*.image.png`` files.
            single: Stored as ``self.single``; nothing consumes it yet.
                NOTE(review): its intended meaning is not visible here -- confirm.
        """
        self.data_path=data_path
        self.single=single
        self.data=sorted(glob(os.path.join(self.data_path,'*.image.png')))

    def __len__(self):
        """Return the number of image files discovered."""
        return len(self.data)

    def __getitem__(self, index):
        """Look up the path for ``index``; loading is not implemented yet.

        Raises:
            IndexError: If ``index`` is outside the discovered file list.
            NotImplementedError: Always, for a valid index.
        """
        # Indexing first keeps IndexError for out-of-range indices.
        datapath=self.data[index]
        raise NotImplementedError(
            f'IG02DataSet cannot load samples yet (requested {datapath!r})'
        )
        

In [60]:
class LaPaDataset(Dataset):
    """Image / label-map pairs resized to 1024x1024, with per-class prompts.

    Images (``*.jpg``) and label maps (``*.png``) are paired by sorted file
    order. Each sample yields, for class ids 1..10, a binary mask, one point
    prompt and one box prompt derived from that mask.
    """

    # Spatial size every image and label map is resized to.
    IMAGE_SIZE=(1024,1024)
    # Class ids for which masks and prompts are produced (0 is skipped).
    CLASS_IDS=tuple(range(1,11))

    def __init__(self,data_path,label_path):
        """Collect and sort the image and label paths.

        Args:
            data_path: Directory globbed for ``*.jpg`` images.
            label_path: Directory globbed for ``*.png`` label maps.

        Raises:
            ValueError: If the two directories hold a different number of
                files, since pairing by position would then be wrong.
        """
        self.data_path=data_path
        self.label_path=label_path

        self.data=sorted(glob(os.path.join(self.data_path,'*.jpg')))
        self.label=sorted(glob(os.path.join(self.label_path,'*.png')))
        if len(self.data)!=len(self.label):
            raise ValueError(
                f'found {len(self.data)} images in {self.data_path!r} but '
                f'{len(self.label)} labels in {self.label_path!r}'
            )

    def __len__(self):
        """Return the number of image / label pairs."""
        return len(self.data)

    @staticmethod
    def _prompts_from_mask(gt_mask):
        """Derive a point prompt and a box prompt from one binary mask.

        Args:
            gt_mask: 2-D ``uint8`` array, non-zero inside the region.

        Returns:
            ``(center_point, bbox)``: an int64 array of shape (1, 2) holding
            ``[x, y]`` and an int64 array ``[x_min, y_min, x_max, y_max]``.
            Both are all zeros when the mask is empty.
        """
        y_idx,x_idx=np.nonzero(gt_mask)
        if y_idx.size==0:
            # Class absent from this image: np.min/np.max would raise on the
            # empty index arrays, so return fixed-shape placeholders instead.
            return np.zeros((1,2),dtype=np.int64),np.zeros(4,dtype=np.int64)

        y_min,y_max=y_idx.min(),y_idx.max()
        x_min,x_max=x_idx.min(),x_idx.max()
        # Zero-pad the crop so its border counts as background; otherwise
        # distances are measured only to background pixels inside the crop.
        crop=np.pad(gt_mask[y_min:y_max+1,x_min:x_max+1],1)
        dt_mask=cv2.distanceTransform(crop,cv2.DIST_L2,3)
        local_y,local_x=np.unravel_index(np.argmax(dt_mask),dt_mask.shape)
        # "-1" removes the padding offset before mapping back to image coords.
        center_point=np.array([[x_min+local_x-1,y_min+local_y-1]],dtype=np.int64)
        bbox=np.array([x_min,y_min,x_max,y_max],dtype=np.int64)
        return center_point,bbox

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            ``(img, label_list, center_list, bbox_list)`` where

            * ``img`` is a float32 tensor (3, 1024, 1024), min-max scaled to [0, 1];
            * ``label_list`` is an int64 tensor (10, 1024, 1024) of binary masks;
            * ``center_list`` is an int64 tensor (10, 1, 2) of ``[x, y]`` points;
            * ``bbox_list`` is an int64 tensor (10, 4) of ``[x0, y0, x1, y1]`` boxes.

            Classes missing from the image get an all-zero mask, point and box.
        """
        # Force three channels so the (2,0,1) transpose below is always valid.
        img=Image.open(self.data[index]).convert('RGB').resize(self.IMAGE_SIZE)
        img=np.array(img,dtype=np.float32)
        # Normalise and reorder once, not once per class.
        low,high=img.min(),img.max()
        img=(img-low)/(high-low if high>low else 1.0)
        img=np.ascontiguousarray(img.transpose((2,0,1)))

        # Nearest-neighbour keeps class ids intact; an interpolating filter
        # would blend neighbouring ids into classes that are not there.
        label=Image.open(self.label[index]).resize(self.IMAGE_SIZE,resample=Image.NEAREST)
        label=np.array(label)

        masks,centers,boxes=[],[],[]
        for class_id in self.CLASS_IDS:
            gt_mask=np.uint8(label==class_id)
            center_point,bbox=self._prompts_from_mask(gt_mask)
            masks.append(gt_mask)
            centers.append(center_point)
            boxes.append(bbox)

        img=torch.from_numpy(img)
        # Stack in numpy first: torch.tensor on a list of arrays is very slow.
        label_list=torch.from_numpy(np.stack(masks)).long()
        center_list=torch.from_numpy(np.stack(centers))
        bbox_list=torch.from_numpy(np.stack(boxes))

        return (img,label_list,center_list,bbox_list)
    
# Module-level training split and a shuffled loader over it (batches of 16).
train_lapa_set=LaPaDataset(
    data_path='./LaPa/train/images',
    label_path='./LaPa/train/labels',
)
train_lapa_loader=DataLoader(
    dataset=train_lapa_set,
    batch_size=16,
    shuffle=True,
)

In [57]:
class finetuneSAM(nn.Module):
    """SAM wrapper in which only the mask decoder receives gradients.

    The image encoder and prompt encoder have ``requires_grad`` switched off
    and are evaluated under ``torch.no_grad()``.
    """

    def __init__(
            self,
            image_encoder:"ImageEncoderViT",
            mask_decoder:"MaskDecoder",
            prompt_encoder:"PromptEncoder",

    ):
        """Store the three sub-modules and freeze both encoders."""
        super().__init__()
        self.image_encoder=image_encoder
        self.mask_decoder=mask_decoder
        self.prompt_encoder=prompt_encoder

        for frozen in (self.image_encoder,self.prompt_encoder):
            for param in frozen.parameters():
                param.requires_grad=False

    def forward(self,image:torch.Tensor,prompt:torch.Tensor,type:str):
        """Predict mask logits at the resolution of ``image``.

        Args:
            image: Batch handed directly to the image encoder; its last two
                dimensions set the output size.
            prompt: Prompt coordinates. With ``type == "single"`` they are
                used as points, all labelled 1; otherwise they are handed to
                the prompt encoder as ``boxes``.
                NOTE(review): callers appear to pass (B, 1, 2) points and
                (B, 4) boxes in pixel units -- confirm.
            type: ``"single"`` selects the point prompt; any other value
                selects the box prompt.

        Returns:
            Mask logits, bilinearly resized to ``image.shape[2:]``, one
            single-channel map per prompt.

        Raises:
            ValueError: If several images are given but the number of prompts
                differs from the number of images.
        """
        with torch.no_grad():
            image_embedding=self.image_encoder(image)
            # Prompts may arrive as integer pixel indices.
            prompt=prompt.float()
            if type=="single":
                label=torch.ones(size=prompt.shape[:-1],dtype=torch.long,device=prompt.device)
                sparse_embeddings,dense_embeddings=self.prompt_encoder(
                    points=(prompt,label),boxes=None,masks=None,
                )
            else:
                sparse_embeddings,dense_embeddings=self.prompt_encoder(
                    points=None,boxes=prompt,masks=None,
                )

        image_pe=self.prompt_encoder.get_dense_pe()
        num_images=image_embedding.shape[0]
        if num_images==1:
            # One image, any number of prompts: the decoder broadcasts the
            # single embedding over all prompts itself.
            low_res_masks,_=self.mask_decoder(
                image_embeddings=image_embedding,
                image_pe=image_pe,
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )
        else:
            if sparse_embeddings.shape[0]!=num_images:
                raise ValueError(
                    f'got {num_images} images but {sparse_embeddings.shape[0]} prompts; '
                    'a multi-image batch needs exactly one prompt per image'
                )
            # The stock segment_anything MaskDecoder repeats the image
            # embedding once per prompt, so B images with B prompts cannot go
            # through it in one call. Decode image by image instead.
            per_image=[]
            for b in range(num_images):
                masks,_=self.mask_decoder(
                    image_embeddings=image_embedding[b:b+1],
                    image_pe=image_pe,
                    sparse_prompt_embeddings=sparse_embeddings[b:b+1],
                    dense_prompt_embeddings=dense_embeddings[b:b+1],
                    multimask_output=False,
                )
                per_image.append(masks)
            low_res_masks=torch.cat(per_image,dim=0)

        ori_res_masks=F.interpolate(low_res_masks,size=(image.shape[2],image.shape[3]),mode='bilinear',
                                      align_corners=False,)
        return ori_res_masks

In [30]:
# def train(dataloader,model,lr=1e-4,epochs=20,method='single'):

#     model.train()
#     seg_loss=monai.losses.DiceCELoss(sigmoid=True,squared_pred=True,reduction='mean')
#     cross_loss=nn.BCEWithLogitsLoss(reduction='mean')
#     optimizer=torch.optim.Adam(model.mask_decoder.parameters(),lr=lr,weight_decay=1e-4)
#     # for para in model.image_decoder.parameters():
#     #     para.required_grad=False
    
#     # for para in model.prompt_decoder.parameters():
#     #     para.required_grad=False
#     total_step=len(dataloader)
#     for epoch in range(epochs):
#         step=0
#         for batch in dataloader:
#             img,embedding_list,label_list,center_list,bbox_list=batch
#             step+=1
#             optimizer.zero_grad()
#             print(embedding_list.shape)
#             for i in range(10):
#                 embedding=embedding_list[:,i,:]
#                 labels=label_list[i]
#                 center_points=center_list[i]
                
#                 box=bbox_list[:,i,:]
#                 box=box[:,None,:]
#                 center=center_list[:,i,:]
#                 if method=='box':
#                     sparse_embeddings,dense_embeddings=model.prompt_decoder(
#                         points=None,
#                         boxer=box,
#                         masks=None
#                     )
#                 elif method=='single':
#                     sparse_embeddings,dense_embeddings=model.prompt_decoder(
#                         points=center,boxer=None,masks=None
#                     )
#                 mask_pred,_=model.mask_decoder(
#                     image_embeddings=embedding,
#                     image_pe=model.prompt_decoder.get_dense_pe(),
#                     sparse_propt_embeddings=sparse_embeddings,
#                     dense_prompt_embeddings=dense_embeddings,
#                     multimask_output=False
#                 )
#                 loss1=seg_loss(seg,label)
#                 loss2=cross_loss(seg,label)
#                 loss=loss1+loss2
#                 loss.backward()
#                 optimizer.step()
#                 if int(100*i/step)%10==0:
#                     print('epoch[{}/{}],step [{}/{}],loss:{:.4f}'.format(epoch+1,epochs,i+1,total_step,loss.cpu().item()))
                    

In [58]:
def train(model,output_path='./res/finetune'
          ,data_path='./LaPa/train/images',label_path='./LaPa/train/labels'
          ,num_epochs=10,batch_size=16,lr=1e-4):
    """Fine-tune the mask decoder of ``model`` on point and box prompts.

    For every batch and every class channel, one optimisation step is taken
    with the point prompt and one with the box prompt, each minimising
    Dice + binary cross-entropy on the predicted logits. After each epoch the
    latest checkpoint, the best-so-far checkpoint and the loss curve are
    written to ``output_path``.

    Args:
        model: Object exposing ``image_encoder``, ``mask_decoder`` and
            ``prompt_encoder``; they are wrapped in ``finetuneSAM``.
        output_path: Directory for checkpoints and the loss plot (created if
            missing).
        data_path: Image directory passed to ``LaPaDataset``.
        label_path: Label directory passed to ``LaPaDataset``.
        num_epochs: Number of passes over the training set.
        batch_size: Batch size of the training ``DataLoader``.
        lr: Adam learning rate for the mask decoder.

    Note:
        Runs on the module-level ``device``. Every step calls the wrapper's
        ``forward``, which re-runs the frozen image encoder; caching the
        embedding once per batch would remove most of the cost.
    """
    os.makedirs(output_path, exist_ok=True)
    finetune_model = finetuneSAM(
        image_encoder=model.image_encoder,
        mask_decoder=model.mask_decoder,
        prompt_encoder=model.prompt_encoder,
    ).to(device)
    finetune_model.train()

    optimizer=torch.optim.Adam(finetune_model.mask_decoder.parameters(),lr=lr,weight_decay=0.01)
    seg_loss=monai.losses.DiceLoss(sigmoid=True,squared_pred=True,reduction="mean")
    ce_loss=nn.BCEWithLogitsLoss(reduction='mean')

    train_dataset=LaPaDataset(data_path=data_path,label_path=label_path)
    train_loader=DataLoader(train_dataset,batch_size=batch_size,shuffle=True)

    losses=[]
    best_loss=float('inf')

    for epoch in range(num_epochs):
        epoch_loss=0.0
        num_steps=0
        for img,label_list,center_list,box_list in train_loader:
            # Move the image batch once, not once per class.
            img=img.to(device)

            for i in range(label_list.shape[1]):
                # Both losses need a float target shaped like the (B,1,H,W)
                # prediction; the dataset yields (B,H,W) int64 masks.
                target=label_list[:,i].unsqueeze(1).float().to(device)
                prompts=(
                    (center_list[:,i].to(device),'single'),
                    (box_list[:,i].to(device),'box'),
                )
                for prompt,prompt_type in prompts:
                    optimizer.zero_grad()
                    finetune_pred=finetune_model(img,prompt,prompt_type)
                    loss=seg_loss(finetune_pred,target)+ce_loss(finetune_pred,target)
                    loss.backward()
                    optimizer.step()
                    epoch_loss += loss.item()
                    num_steps += 1

        # Average over the steps actually taken (classes x prompt types x batches).
        epoch_loss /= max(num_steps,1)
        losses.append(epoch_loss)

        print(
            f'Time: {datetime.now().strftime("%Y%m%d-%H%M")}, Epoch: {epoch}, Loss: {epoch_loss}'
        )

        checkpoint = {
            'model': finetune_model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
            'loss': epoch_loss,
        }
        torch.save(checkpoint, join(output_path, 'finetune_latest.pth'))

        if epoch_loss < best_loss:
            best_loss = epoch_loss
            torch.save(checkpoint, join(output_path, 'finetune_best.pth'))

        fig,ax=plt.subplots()
        ax.plot(losses)
        ax.set_title("Dice + Cross Entropy Loss")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        # Save before showing: after plt.show() the figure may already be
        # released, and savefig would then write an empty canvas.
        fig.savefig(join(output_path, 'finetune_loss.png'))
        plt.show()
        plt.close(fig)

In [61]:
# Start fine-tuning sam_model with train()'s default paths.
train(model=sam_model)

KeyboardInterrupt: 