In [1]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torch import nn
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import numpy as np
from PIL import Image, ImageDraw
import pandas as pd
from utils import utils
from models.unet import UNet
import os
import wandb
import time

# Start a Weights & Biases run; the wandb.watch(model) and wandb.log(...) calls in later cells report into it.
wandb.init()

  from .autonotebook import tqdm as notebook_tqdm
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mhhebb[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
# Training-time augmentation. The overlay code in the training loop undoes A.Normalize() with the
# ImageNet mean/std constants, which are albumentations' Normalize defaults.
transform = A.Compose(
    [
        A.Normalize(),
        A.Resize(480, 480),
        A.RandomCrop(320, 320),
        # A.ColorJitter(brightness=.05, contrast=.05, saturation=.05, hue=.05, p=.2),
        A.Affine(translate_percent=.2),
        A.Rotate(limit=30),
        ToTensorV2(),
    ]
)

# Normalization only, no resizing: used by the directory-inference cell on full-resolution frames.
transform_valid = A.Compose(
    [
        A.Normalize(),
        ToTensorV2(),
    ]
)

# Deterministic counterpart of `transform` for the held-out split: same 480 -> 320 geometry, but a
# centre crop and no random affine/rotation, so validation Dice is comparable from epoch to epoch.
transform_holdout = A.Compose(
    [
        A.Normalize(),
        A.Resize(480, 480),
        A.CenterCrop(320, 320),
        ToTensorV2(),
    ]
)

TRAIN_FRACTION = .8
SPLIT_SEED = 0  # fixes which samples land in train vs. validation across kernel restarts

ds = utils.Dataset_synth(transform)
n_train = int(len(ds) * TRAIN_FRACTION)
train_set, val_set = random_split(
    ds,
    [n_train, len(ds) - n_train],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
# random_split returns Subsets of the *same* `ds`, so val_set would inherit the random training
# augmentation. Re-index the same positions into a dataset built with the deterministic transform.
# NOTE(review): assumes Dataset_synth(transform) yields the same sample for a given index regardless
# of the transform passed (i.e. construction is deterministic) — confirm in utils.
val_set = torch.utils.data.Subset(utils.Dataset_synth(transform_holdout), val_set.indices)

train_loader = DataLoader(train_set, batch_size=16, shuffle=True, num_workers=2, pin_memory=True)
valid_loader = DataLoader(val_set, batch_size=32)

In [3]:
# Pull one sample from the (augmented) training dataset to sanity-check tensor shapes in the next cell.
img, mask = ds[1]

In [4]:
import matplotlib.pyplot as plt

# Convert the sample from the previous cell to numpy under *new* names. Rebinding img/mask in place
# made this cell non-idempotent: a second run called .numpy() on an ndarray and raised.
img_np = img.numpy()
mask_np = mask.numpy()
print(img_np.shape, mask_np.shape)
# plt.imshow(mask_np, cmap='gray')

(3, 320, 320) (320, 320)


In [5]:
class Model(nn.Module):
    """torchvision FCN-ResNet50 followed by a 1x1 conv and a sigmoid, giving per-pixel probabilities.

    Args:
        weights: optional torchvision ``FCN_ResNet50_Weights`` (e.g. ``FCN_ResNet50_Weights.DEFAULT``)
            used to initialise the whole FCN. The default ``None`` keeps the original behaviour of
            building the FCN without segmentation weights.
        num_classes: channel count of the FCN's ``'out'`` head, which feeds the 1x1 conv. It should
            stay 21 when ``weights`` is given; presumably torchvision validates it against the weights'
            category list (NOTE(review): confirm for your torchvision version).
        out_channels: channels of the final output. The default of 3 matches the 3-channel masks the
            training loop builds.
    """

    def __init__(self, weights=None, num_classes=21, out_channels=3):
        super().__init__()
        # The previous version assigned FCN_ResNet50_Weights.DEFAULT to a local and never used it.
        # That read as if pretrained segmentation weights were loaded when they were not. Exposing
        # it as an opt-in argument makes the choice explicit without changing the default model.
        # NOTE(review): passing `weights` likely also enables torchvision's auxiliary classifier,
        # which adds `aux_classifier.*` keys to state_dict(). Checkpoints would then not be
        # interchangeable between the two modes — confirm before switching.
        self.model = fcn_resnet50(weights=weights, num_classes=num_classes)
        self.conv = nn.Conv2d(num_classes, out_channels, 1)

    def forward(self, x):
        """Return sigmoid probabilities with `out_channels` channels for input batch `x`.

        The spatial size is whatever the FCN's ``'out'`` head produces (presumably the input size).
        """
        x = self.model(x)['out']
        x = self.conv(x)
        x = torch.sigmoid(x)

        return x

In [6]:
def BCEDice(pred, gt):
    """Training loss: mean binary cross-entropy plus the soft-Dice loss (1 - Dice).

    Both `pred` (probabilities in [0, 1]) and `gt` must have the same shape. Returns a 0-dim tensor.
    """
    bce_term = nn.functional.binary_cross_entropy(pred, gt)
    dice_term = 1 - get_dice(pred, gt)
    return bce_term + dice_term

def get_dice(pred, gt):
    """Soft (differentiable) Dice coefficient pooled over every element of `pred` and `gt`.

    Computes 2 * sum(pred * gt) / (sum(gt) + sum(pred) + 1e-5); the epsilon avoids 0/0 on empty masks.
    """
    smooth = 1e-5
    overlap = (gt * pred).sum()
    total = gt.sum() + pred.sum()
    return 2 * overlap / (total + smooth)

def get_dice_metric(pred, gt, threshold=.5):
    """Hard Dice score for reporting: binarise `pred` at `threshold`, then compute Dice.

    Computes 2 * |pred_bin * gt| / (|gt| + |pred_bin| + 1e-5) pooled over all elements.
    The result is not differentiable.

    Args:
        pred: predicted probabilities, same shape as `gt`.
        gt: ground-truth mask tensor (presumably float 0/1, as BCELoss needs in training — TODO confirm).
        threshold: a pixel counts as foreground when `pred > threshold` (strict). The default of .5
            is the previously hard-coded cut-off.

    Returns:
        0-dim tensor with the Dice score.
    """
    eps = 1e-5
    pred = pred > threshold
    summ = torch.sum(gt) + torch.sum(pred)
    inter = torch.sum(gt * pred)
    dice = 2 * inter / (summ + eps)

    return dice
    
# Build the FCN model on the GPU (the commented UNet line is the alternative architecture),
# create its Adam optimizer, and register the model with wandb.watch.
model = Model().cuda()
# model = UNet(n_channels=3, n_classes=3).cuda()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=0)
wandb.watch(model)

In [None]:
# Training loop. Each epoch runs BCE+Dice optimisation over train_loader, then hard Dice on
# valid_loader (plus overlay images every SAMPLE_EVERY epochs). It logs to W&B and overwrites ckpt.pt.
N_EPOCHS = 100
SAMPLE_EVERY = 5  # save validation overlay images every N epochs
SAMPLE_DIR = r'\\wsl.localhost\Ubuntu-20.04\home\hebb\ml\project_hand\hand_seg\output\test'
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def _to_three_channels(masks):
    """Repeat a (B, H, W) mask along a new channel dim -> (B, 3, H, W), matching the model's 3 outputs."""
    return masks.unsqueeze(1).repeat(1, 3, 1, 1)


def _save_overlays(imgs, probs, batch_idx, out_dir=SAMPLE_DIR):
    """Save each de-normalised input image with its predicted probability map added on top, as a JPEG."""
    for sample_idx, (img_t, prob_t) in enumerate(zip(imgs, probs)):
        # Undo the ImageNet mean/std normalisation applied by A.Normalize() to get back to 0..255 RGB.
        img = img_t.detach().cpu().numpy().transpose(1, 2, 0)
        img = (img * IMAGENET_STD + IMAGENET_MEAN) * 255
        overlay = prob_t.detach().cpu().numpy().transpose(1, 2, 0) * 255

        img_u8 = np.clip(img, 0, 255).astype(np.uint8)
        overlay_u8 = np.clip(overlay, 0, 255).astype(np.uint8)
        # Add in a wider dtype: uint8 + uint8 wraps modulo 256, which made the old np.clip a no-op.
        blended = np.clip(img_u8.astype(np.int16) + overlay_u8, 0, 255).astype(np.uint8)

        Image.fromarray(blended).save(os.path.join(out_dir, f"{batch_idx}_{sample_idx}.jpg"))


for epoch in range(N_EPOCHS):
    model.train()  # re-enable batch-stat BatchNorm / dropout after the previous epoch's eval()
    total_loss = 0.0
    start = time.time()
    for imgs, masks in train_loader:
        imgs = imgs.cuda()
        masks = _to_three_channels(masks.cuda())

        loss = BCEDice(model(imgs), masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() yields a plain float. Accumulating the loss tensor itself keeps autograd history
        # alive across the whole epoch.
        total_loss += loss.item()

    # Inference mode for validation: BatchNorm uses its running stats (and is not updated by
    # validation batches), and dropout is disabled.
    model.eval()
    with torch.no_grad():
        dice = 0.0
        for batch_idx, (imgs, masks) in enumerate(valid_loader):
            imgs = imgs.cuda()
            masks = _to_three_channels(masks.cuda())

            pred = model(imgs)
            dice += get_dice_metric(pred, masks).item()

            if epoch % SAMPLE_EVERY == 0:
                _save_overlays(imgs, pred, batch_idx)

    wandb.log(
        {
            'dice': dice / max(len(valid_loader), 1),
            'loss': total_loss / max(len(train_loader), 1),
            'elapse': time.time() - start,
        }
    )

    torch.save(model.state_dict(), 'ckpt.pt')
    # logging epoch

In [None]:
# specific directory test

import os
from glob import glob

# Run the saved checkpoint over every "frame" image in eval_path and write binary masks to ./output.
eval_path = r'\\wsl.localhost\Ubuntu-20.04\home\hebb\ml\datasets\egohand\_LABELLED_SAMPLES\PUZZLE_OFFICE_T_S'
out_dir = './output'
os.makedirs(out_dir, exist_ok=True)  # Image.save fails if the folder is missing

model = Model().cuda()
model.load_state_dict(torch.load('ckpt.pt'))
model.eval()  # inference: BatchNorm running stats, no dropout

with torch.no_grad():
    # os.path.join avoids the invalid "\*" escape sequence in the old f-string pattern.
    for path in glob(os.path.join(eval_path, '*')):
        base = os.path.basename(path)
        if 'frame' not in base:
            continue

        # Close the file handle promptly; convert('RGB') keeps Normalize's 3-channel stats valid
        # for grayscale/RGBA inputs (a no-op for RGB images).
        with Image.open(path) as pil_img:
            img = np.array(pil_img.convert('RGB'))
        batch = torch.unsqueeze(transform_valid(image=img)['image'], 0).cuda()
        # batch = torch.unsqueeze(torch.tensor(img).permute(2, 0, 1), 0).type(torch.float).cuda()
        pred = model(batch)
        hand_mask = pred[0].permute(1, 2, 0) > .5
        image = Image.fromarray((hand_mask.cpu().numpy() * 255).astype(np.uint8))
        image.save(os.path.join(out_dir, base))
