In [22]:
import torch 
import torch.nn as nn
from torchvision.models.segmentation import deeplabv3_resnet50
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, utils
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image
import torchvision.transforms.functional as TF
from torchmetrics.classification import F1Score
from torchmetrics import Accuracy 
from sklearn.model_selection import train_test_split
import torch.optim as optim
from utils_segmentation import p3



In [23]:
def prepare_model(num_classes=8):
    model = deeplabv3_resnet50(weights='DEFAULT')
    model.classifier[4] = nn.Conv2d(256, num_classes, 1)
    model.aux_classifier[4] = nn.Conv2d(256, num_classes, 1)
    return model

In [24]:
def Transform(sample):
    mask = sample['mask']
    image = sample['image']
    b = np.random.uniform(np.random.uniform())
    b = TF.adjust_brightness(image,b)
    #TF.crop
    r = np.random.uniform()
    if r>0.5:
        image = TF.hflip(image)
        mask = TF.hflip(mask)
        
    r = np.random.uniform()
    if r>0.5:
        image = TF.vflip(image)
        mask = TF.vflip(mask)        
    return sample

In [51]:
class DscSegmentationDataset(Dataset):
    def __init__(self, image_list, mask_list, transform=None):
        """
        Arguments:
            image_list (string): all the masks.
            mask_list (string): all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.image_list = image_list
        self.mask_list = mask_list
        self.transform = transform

    def __len__(self):
        return len(image_list)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = image_list[idx]
        mask_name = mask_list[idx]
        print(img_name)
        print(mask_name)
        image = Image.open(img_name)
        mask = Image.open(mask_name)
        image = image.convert("RGB")
        sample = {'image': image, 'mask': mask}

        if self.transform:
            sample = self.transform(sample)

        mask = sample['mask']
        image = sample['image']       
        mask = TF.pil_to_tensor(mask)
        image = TF.pil_to_tensor(image)
        image = image.type(torch.FloatTensor)/255
        image = TF.normalize(image,mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        sample = {'image': image, 'mask': mask}

        return sample

In [52]:

version='V1'
p1 = os.path.join('../../','work','CookIRCamET','Images','CookHY2023',version,'TifPng','RGB')
p2 = os.path.join('../../','work','CookIRCamET','Images','CookHY2023',version,'TifPng')
p11 = os.path.join('../../','work','CookIRCamET','Images','CprlHY2023',version,'TifPng','RGB')
p22 = os.path.join('../../','work','CookIRCamET','Images','CprlHY2023',version,'TifPng')

image_list=[]
mask_list=[]
n_img=0
for di,do in zip([p1,p11],[p2,p22]):
    fs=os.listdir(di)
    print(di)
    print(do)
    for f in fs:
        if 'bgr' in f:
            if os.path.exists(os.path.join(do,'Masks',f.split('_bgr')[0]+'_class8.png')):
                image_list.append(os.path.join(di,f))            
                mask_list.append(os.path.join(do,'Masks',f.split('_bgr')[0]+'_class8.png'))
            

../../work/CookIRCamET/Images/CookHY2023/V1/TifPng/RGB
../../work/CookIRCamET/Images/CookHY2023/V1/TifPng
../../work/CookIRCamET/Images/CprlHY2023/V1/TifPng/RGB
../../work/CookIRCamET/Images/CprlHY2023/V1/TifPng


In [53]:
images_train, images_test, masks_train, masks_test = train_test_split(image_list, mask_list, test_size=0.2)

In [54]:
dsc_train = DscSegmentationDataset(images_train,masks_train,transform=Transform)

In [55]:
dsc_test = DscSegmentationDataset(images_test,masks_test,transform=None)

In [56]:
model = prepare_model(num_classes=8)

In [57]:
dsc_test.transform

In [58]:
dsc_train.transform


<function __main__.Transform(sample)>

In [64]:
train_loader = DataLoader(dsc_train, batch_size=14, shuffle=True, num_workers=0)

In [65]:
test_loader = DataLoader(dsc_test, batch_size=len(dsc_test), shuffle=True, num_workers=0)

In [66]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

In [67]:
len(images_train)

42

In [None]:
for epoch in range(100):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        # get the inputs; data is a list of [inputs, masks]
        image = data['image']
        mask = data['mask'].squeeze().type(torch.LongTensor)
        
        
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(image)
        loss = criterion(outputs['out'], mask)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2 == 1:    # print every 2 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2:.3f}')
            running_loss = 0.0

print('Finished Training')

../../work/CookIRCamET/Images/CookHY2023/V1/TifPng/RGB/20230109232457_-117.081797_46.781650_bgr.png
../../work/CookIRCamET/Images/CookHY2023/V1/TifPng/Masks/20230109232457_-117.081797_46.781650_class8.png
../../work/CookIRCamET/Images/CprlHY2023/V1/TifPng/RGB/20230506214817_-102.095543_35.188177_bgr.png
../../work/CookIRCamET/Images/CprlHY2023/V1/TifPng/Masks/20230506214817_-102.095543_35.188177_class8.png
../../work/CookIRCamET/Images/CookHY2023/V1/TifPng/RGB/20221102210331_-117.081888_46.781508_bgr.png
../../work/CookIRCamET/Images/CookHY2023/V1/TifPng/Masks/20221102210331_-117.081888_46.781508_class8.png
../../work/CookIRCamET/Images/CprlHY2023/V1/TifPng/RGB/20230415200907_nofix_bgr.png
../../work/CookIRCamET/Images/CprlHY2023/V1/TifPng/Masks/20230415200907_nofix_class8.png
../../work/CookIRCamET/Images/CprlHY2023/V1/TifPng/RGB/20230415203912_nofix_bgr.png
../../work/CookIRCamET/Images/CprlHY2023/V1/TifPng/Masks/20230415203912_nofix_class8.png
../../work/CookIRCamET/Images/CookHY202

In [None]:
#evaluation
f1 = F1Score(task="multiclass", num_classes=8)
acc = Accuracy(task="multiclass",num_classes=8)
for i, data in enumerate(test_loader, 0):
    # get the inputs; data is a list of [inputs, masks]
    image = data['image']
    mask = data['mask'].squeeze().type(torch.LongTensor)

    outputs = model(image)
    F1 = f1(outputs['out'].argmax(1).squeeze(), mask)
    A = acc(outputs['out'].argmax(1).squeeze(), mask)



In [None]:
print(F1)

In [None]:
print(A)

In [None]:
torch.save(model.state_dict(), os.path.join(p3,'model_deeplab_'+version+'.sv'))