In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Recursively walk the Kaggle input directory and print the full path of every
# file found, one per line (directories themselves are not printed).
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import os
import cv2
import glob
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import glob
from PIL import Image
import torch
from torch import nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
from tqdm.notebook import tqdm
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
from torch.autograd import Variable

In [None]:
def _resize_then_crop():
    """Return the geometry steps shared by images and masks.

    Both pipelines resize to 256x256 (the second ``Resize`` argument, 2, is the
    integer interpolation code) and then take a 224-pixel centre crop.
    """
    return [transforms.Resize((256, 256), 2), transforms.CenterCrop(224)]


# RGB inputs: shared geometry, tensor conversion, then per-channel normalisation.
image_transform = transforms.Compose(
    _resize_then_crop()
    + [
        transforms.ToTensor(),
        # imagenet stats
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Masks: same geometry and tensor conversion, but no normalisation.
mask_transform = transforms.Compose(_resize_then_crop() + [transforms.ToTensor()])

In [None]:
class CelebAHQ(Dataset):
    """Paired image / two-channel mask dataset discovered from a folder of masks.

    Every file found in ``parent_dir/mask_dir`` is treated as a mask. The path of
    the matching image is derived from the mask path by replacing ``'.png'`` with
    ``'.jpg'`` and ``mask_dir`` with ``image_dir``; the image files are not
    checked for existence at construction time.

    Args:
        parent_dir: Directory that contains both ``image_dir`` and ``mask_dir``.
        image_dir: Image folder, relative to ``parent_dir``.
        mask_dir: Mask folder, relative to ``parent_dir``.
        image_transform: Callable applied to the RGB PIL image.
        mask_transform: Callable applied to the single-channel ('L') PIL mask.
            Its result is cast to a boolean tensor and concatenated along
            dim 0 — presumably a (1, H, W) tensor; confirm against the
            transform that is passed in.
    """

    def __init__(self, parent_dir, image_dir, mask_dir, image_transform, mask_transform):
        # Sorted so that indices are stable from run to run.
        self.mask_list = sorted(glob.glob(parent_dir+'/'+mask_dir+'/*'))

        # An image is assumed to exist for every mask.
        self.image_list = [
            path.replace('.png', '.jpg').replace(mask_dir, image_dir)
            for path in self.mask_list
        ]

        self.image_transform = image_transform
        self.mask_transform = mask_transform

    def __len__(self):
        """Return the number of (image, mask) pairs."""
        return len(self.image_list)

    def __getitem__(self, index):
        """Load and transform pair ``index``.

        Returns:
            ``(img, mask)`` where ``img`` is the transformed image and ``mask``
            is a boolean tensor whose channel 0 is "mask not present" and
            channel 1 is "mask present".
        """
        # Use the transforms handed to the constructor rather than whatever
        # globals happen to be called image_transform / mask_transform.
        img = Image.open(self.image_list[index]).convert('RGB')
        img = self.image_transform(img)

        mask = Image.open(self.mask_list[index]).convert('L')
        mask_present = self.mask_transform(mask)

        # Any non-zero mask value counts as "present".
        mask_present = mask_present.type(torch.BoolTensor)
        mask_not_present = torch.bitwise_not(mask_present)
        mask = torch.cat([mask_not_present, mask_present], dim=0)

        return img, mask

In [None]:
# Build the dataset from the 720p folder of the lips-segmentation input.
# NOTE(review): the same folder is passed as both image_dir and mask_dir, so
# every file in it is treated as a mask and the mask_dir -> image_dir
# substitution inside CelebAHQ.__init__ changes nothing. Confirm that this is
# intended; otherwise mask_dir should point at the folder holding the masks.
dataset = CelebAHQ('..', 'input/makeup-lips-segmentation-28k-samples/set-lipstick-original/720p/', 'input/makeup-lips-segmentation-28k-samples/set-lipstick-original/720p/', image_transform, mask_transform)

In [None]:
# Number of (image, mask) pairs found; displayed as the cell's output.
len(dataset)

In [None]:
def train_val_test_split(dataset, train_end=0.8, val_end=0.9):
    """Split ``dataset`` into contiguous train / validation / test subsets.

    The split is positional (no shuffling): the first ``train_end`` fraction of
    the indices is the training set, indices up to the ``val_end`` fraction are
    the validation set, and the remainder is the test set. Boundaries are
    truncated to whole indices with ``int``.

    Args:
        dataset: Any object supporting ``len()`` and integer indexing.
        train_end: Cumulative fraction at which the training subset ends.
        val_end: Cumulative fraction at which the validation subset ends.

    Returns:
        ``(train_dataset, val_dataset, test_dataset)`` as ``torch.utils.data.Subset``s.

    Raises:
        ValueError: If the fractions do not satisfy
            ``0 <= train_end <= val_end <= 1``.
    """
    if not 0.0 <= train_end <= val_end <= 1.0:
        raise ValueError(
            f"expected 0 <= train_end <= val_end <= 1, got train_end={train_end}, val_end={val_end}"
        )

    total = len(dataset)
    train_stop = int(train_end * total)
    val_stop = int(val_end * total)

    train_dataset = torch.utils.data.Subset(dataset, range(0, train_stop))
    val_dataset = torch.utils.data.Subset(dataset, range(train_stop, val_stop))
    test_dataset = torch.utils.data.Subset(dataset, range(val_stop, total))
    return train_dataset, val_dataset, test_dataset

In [None]:
batch_size = 40
train_dataset, val_dataset, test_dataset = train_val_test_split(dataset)

# Options common to all three loaders; only the training loader shuffles.
# drop_last=True means a trailing partial batch is discarded in every split.
_loader_options = dict(batch_size=batch_size, drop_last=True, num_workers=4, pin_memory=True)

train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, **_loader_options)
test_loader = DataLoader(test_dataset, **_loader_options)

In [None]:
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def to_device(data, device):
    """Move ``data`` to ``device``, recursing into lists and tuples.

    Lists and tuples are walked element by element and always come back as a
    list. Anything else is moved with ``data.to(device, non_blocking=True)``
    and the result of that call is returned.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking = True)
    return [to_device(item, device) for item in data]

class DeviceDataLoader():
    """Wrap a dataloader so that every batch it produces is moved to a device."""

    def __init__(self, dl, device):
        self.dl, self.device = dl, device

    def __len__(self):
        # Same number of batches as the wrapped loader.
        return len(self.dl)

    def __iter__(self):
        # Lazily hand out each batch of the wrapped loader after moving it.
        for batch in self.dl:
            yield to_device(batch, self.device)

In [None]:
# Rebind the three loader names to device-moving wrappers around the originals.
train_loader, val_loader, test_loader = (
    DeviceDataLoader(loader, device)
    for loader in (train_loader, val_loader, test_loader)
)

In [None]:
def mean_IOU(target, predicted): 
    """Return the batch-averaged IoU of the non-background region.

    Both inputs are reduced to a per-pixel class index with ``argmax`` over
    dim 1; a pixel counts as foreground when that index is non-zero. For every
    sample the intersection-over-union of the target and predicted foreground
    is computed, and the per-sample scores are averaged. A sample whose union
    is empty (no foreground in either input) scores 0.

    Args:
        target: Tensor of shape (batch, classes, ...) holding the ground truth.
        predicted: Tensor of the same layout holding scores / logits.

    Returns:
        The mean IoU as a plain Python ``float``; ``0.0`` for an empty batch.
    """
    batch_size = target.shape[0]
    if batch_size == 0:
        # Nothing to average; avoid dividing by zero.
        return 0.0

    with torch.no_grad():
        # argmax is not defined for bool tensors, so promote those first.
        target_fg, predicted_fg = (
            (t.to(torch.uint8) if t.dtype == torch.bool else t).argmax(dim=1) != 0
            for t in (target, predicted)
        )
        # Only the per-sample pixel counts are copied to the host.
        intersection = (target_fg & predicted_fg).flatten(1).sum(dim=1).cpu().double()
        union = (target_fg | predicted_fg).flatten(1).sum(dim=1).cpu().double()

    # clamp keeps the division finite where union == 0; those entries are
    # replaced by 0 anyway.
    iou = torch.where(union > 0, intersection / union.clamp(min=1.0), torch.zeros_like(union))
    return iou.mean().item()

In [None]:
def pixel_acc(target, predicted):     
    """Return the batch-averaged fraction of pixels whose class is predicted correctly.

    Both inputs are reduced to a per-pixel class index with ``argmax`` over
    dim 1. For every sample the share of positions where the two indices agree
    is computed, and the per-sample shares are averaged.

    Args:
        target: Tensor of shape (batch, classes, ...) holding the ground truth.
        predicted: Tensor of the same layout holding scores / logits.

    Returns:
        The mean pixel accuracy as a plain Python ``float``; ``0.0`` for an
        empty batch.
    """
    batch_size = target.shape[0]
    if batch_size == 0:
        # Nothing to average; avoid dividing by zero.
        return 0.0

    with torch.no_grad():
        # argmax is not defined for bool tensors, so promote those first.
        target_cls, predicted_cls = (
            (t.to(torch.uint8) if t.dtype == torch.bool else t).argmax(dim=1)
            for t in (target, predicted)
        )
        matches = (target_cls == predicted_cls).flatten(1)
        # Only the per-sample match counts are copied to the host.
        per_sample = matches.sum(dim=1).cpu().double() / matches.shape[1]

    return per_sample.mean().item()

In [None]:
class DeepLabv3(nn.Module):
    """torchvision ``deeplabv3_resnet50`` with its classifier replaced by a two-channel head."""

    def __init__(self):
        super().__init__()
        # Build the network without the pretrained segmentation weights.
        network = models.segmentation.deeplabv3_resnet50(pretrained=False, progress=True)

        # Our mask image (224x224x2) has two channels, so the head outputs two classes.
        network.classifier = DeepLabHead(2048, 2)
        self.deep_lab_v3 = network

    def forward(self, input):
        """Return the ``'out'`` entry of the wrapped model's output for ``input``."""
        return self.deep_lab_v3(input)['out']

In [None]:
# Instantiate the segmentation network and move it to the selected device.
model = DeepLabv3()
# Last expression of the cell: whatever to_device returns is shown as the output.
to_device(model, device)

In [None]:
# Binary cross-entropy computed on raw logits.
loss_fn = nn.BCEWithLogitsLoss()

# Adam over all model parameters; the scheduler multiplies the learning rate
# by 0.8 each time it is stepped.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)

In [None]:
def fit(epochs, optimizer, lr_scheduler, model, loss_fn, train_loader, val_loader, last_ckpt_path = None):   
    """Train ``model`` for ``epochs`` epochs, validating and checkpointing after each.

    For every epoch the function runs one pass over ``train_loader`` with
    gradient updates, one pass over ``val_loader`` without gradients, steps the
    learning-rate scheduler, and writes ``"<epoch number>.pth"`` to the current
    directory. Relies on the notebook-level ``device``, ``to_device``,
    ``mean_IOU``, ``pixel_acc`` and ``tqdm``.

    Args:
        epochs: Number of epochs to run in this call.
        optimizer: Optimizer updating ``model``'s parameters.
        lr_scheduler: Scheduler stepped once per epoch.
        model: Network mapping an input batch to logits.
        loss_fn: Callable ``loss_fn(pred, mask)`` returning a scalar loss tensor.
        train_loader: Iterable (with ``len``) of ``(input, mask)`` training batches.
        val_loader: Iterable (with ``len``) of ``(input, mask)`` validation batches.
        last_ckpt_path: Optional path of a checkpoint written by this function.
            When given, model / optimizer / (if present) scheduler state and the
            metric histories are restored and epoch numbering continues from it.

    Returns:
        ``(tr_loss_arr, val_loss_arr, meanioutrain, pixelacctrain, meanioutest,
        pixelacctest)`` — per-batch histories as lists of Python floats. The
        ``*test`` lists (and the ``'... test'`` checkpoint keys) hold the
        metrics measured on ``val_loader``.
    """
    tr_loss_arr = []
    val_loss_arr = []
    meanioutrain = []
    pixelacctrain = []
    meanioutest = []
    pixelacctest = []
    prev_epoch = 0
    
    if last_ckpt_path is not None:
        # map_location lets a checkpoint written on a GPU be resumed on a
        # machine where that GPU is not available.
        checkpoint = torch.load(last_ckpt_path, map_location=device)
        prev_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        to_device(model, device)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = to_device(v, device)
        # Older checkpoints do not carry the scheduler state; keep them loadable.
        if 'lr_scheduler_state_dict' in checkpoint:
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
        # Histories are restored unconditionally (not only when the optimizer
        # state happens to contain tensors).
        tr_loss_arr =  checkpoint['Training Loss']
        val_loss_arr =  checkpoint['Validation Loss']
        meanioutrain =  checkpoint['MeanIOU train']
        pixelacctrain =  checkpoint['PixelAcc train']
        meanioutest =  checkpoint['MeanIOU test']
        pixelacctest =  checkpoint['PixelAcc test']
    
    for epoch in range(epochs):
        epoch_number = epoch + 1 + prev_epoch

        # ---- training pass ----
        model.train()
        seen = 0  # batches processed in this epoch, used for the per-epoch display
        pbar = tqdm(train_loader, total = len(train_loader))
        for images, mask in pbar:
            # NOTE(review): emptying the CUDA cache on every batch is kept from
            # the original; it is usually unnecessary and slows training.
            torch.cuda.empty_cache()
            images = images.float()
            mask = mask.float()
            pred = model(images)
            loss = loss_fn(pred, mask)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # float(...) keeps NumPy scalars out of the histories / checkpoints.
            tr_loss_arr.append(loss.item())
            meanioutrain.append(float(mean_IOU(mask, pred)))
            pixelacctrain.append(float(pixel_acc(mask, pred)))
            seen += 1
            # Show running means of the current epoch only, not of the whole history.
            pbar.set_postfix({'Epoch':epoch_number, 
                              'Training Loss': np.mean(tr_loss_arr[-seen:]),
                              'Mean IOU': np.mean(meanioutrain[-seen:]),
                              'Pixel Acc': np.mean(pixelacctrain[-seen:])})
            
        # ---- validation pass ----
        model.eval()
        with torch.no_grad():
            seen = 0
            pbar = tqdm(val_loader, total = len(val_loader))
            for images, mask in pbar:
                torch.cuda.empty_cache()
                images = images.float()
                mask = mask.float()
                pred = model(images)
                
                val_loss_arr.append(loss_fn(pred, mask).item())
                pixelacctest.append(float(pixel_acc(mask, pred)))
                meanioutest.append(float(mean_IOU(mask, pred)))
                seen += 1
                
                pbar.set_postfix({'Epoch':epoch_number, 
                                  'Validation Loss': np.mean(val_loss_arr[-seen:]),
                                  'Mean IOU': np.mean(meanioutest[-seen:]),
                                  'Pixel Acc': np.mean(pixelacctest[-seen:])
                                 })
        
        # Step before saving so the checkpoint holds the learning rate that the
        # next epoch should use; otherwise a resumed run repeats the old rate.
        lr_scheduler.step()

        checkpoint = {
            'epoch':epoch_number,
            'state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'lr_scheduler_state_dict': lr_scheduler.state_dict(),
            'Training Loss': tr_loss_arr,
            'Validation Loss':val_loss_arr,
            'MeanIOU train':meanioutrain, 
            'PixelAcc train':pixelacctrain, 
            'MeanIOU test':meanioutest, 
            'PixelAcc test':pixelacctest
        }
        torch.save(checkpoint, f"{epoch_number}.pth")
        
    return tr_loss_arr, val_loss_arr, meanioutrain, pixelacctrain, meanioutest, pixelacctest

In [None]:
# Train for three epochs from scratch; the tuple of metric histories returned
# by fit is kept as the single element of a list.
retval = [
    fit(
        epochs=3,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        model=model,
        loss_fn=loss_fn,
        train_loader=train_loader,
        val_loader=val_loader,
    )
]