# Torch Segmentation Model with Lab Dataset

Based off of:

https://towardsdatascience.com/train-neural-net-for-semantic-segmentation-with-pytorch-in-50-lines-of-code-830c71a6544f
https://github.com/sagieppel/Train-Semantic-Segmentation-Net-with-Pytorch-In-50-Lines-Of-Code 

MIT License - use available for commercial use

In [30]:
import os
import numpy as np
import cv2
import torchvision.models.segmentation
import torch
import torchvision.transforms as tf
import random

Learning_Rate=1e-5
width=height=800 # image width and height
batchSize=3

In [31]:
TrainFolder="Data/LabPicsV1/Simple/Train/"
ListImages=os.listdir(os.path.join(TrainFolder, "Image"))

In [32]:
transformImg=tf.Compose([tf.ToPILImage(),tf.Resize((height,width)), tf.ToTensor(),tf.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

transformAnn=tf.Compose([tf.ToPILImage(),tf.Resize((height,width)), tf.ToTensor()])

In [33]:
# Read custom image set into cache so that we don't have to load on every single run
custom_image_loaded = []
custom_masks_loaded = []

def loadImages(img_arr, mask_arr, full_path_images, full_path_masks): # full path is the parent folder of images
    data = os.listdir(full_path_images)
    for file in data:
        img_single_path = os.path.join(full_path_images, file)
        mask_single_path = os.path.join(full_path_masks, file)

        img_arr.append(cv2.imread(img_single_path))
        temp_mask = cv2.imread(os.path.join(mask_single_path.replace("jpeg","png").replace("jpg","png"))).max(axis=2) # perhaps switch custom_masks[idx] with custom_imgs[idx] to match image
        temp_mask[temp_mask > 0] = 1
        mask_arr.append(temp_mask)

full_path_images = "../../Skin_Anatomical_Image_Dataset/simple_image_config/images"
full_path_masks = "../../Skin_Anatomical_Image_Dataset/simple_image_config/masks"
loadImages(custom_image_loaded, custom_masks_loaded, full_path_images, full_path_masks)

print("Length of images and masked images arrays: ", len(custom_image_loaded), ", ", len(custom_masks_loaded))

Length of images and masked images arrays:  69 ,  69


In [34]:
probability = 0.25
second_train_folder = "../../Skin_Anatomical_Image_Dataset/simple_image_config"

cust_img_path = os.path.join(second_train_folder, "images")
cust_mask_path = os.path.join(second_train_folder, "masks")

custom_imgs = os.listdir(cust_img_path)
custom_masks = os.listdir(cust_mask_path)

# Get length of directory
len_custom_masks = len(custom_masks)
print("Number of files in custom directory: ", len_custom_masks)

def ReadRandomImage(show=False):

    # Probabilistically read in an image that is from our custom dataset
    if random.random() < probability:
        idx=np.random.randint(0,len_custom_masks)
        Filled = None

        # Replace loading by our cache
        # Img=cv2.imread(os.path.join(cust_img_path,custom_imgs[idx]))
        # Vessel =  cv2.imread(os.path.join(cust_mask_path, custom_imgs[idx].replace("jpeg","png").replace("jpg","png"))).max(axis=2) # perhaps switch custom_masks[idx] with custom_imgs[idx] to match image
        # Vessel[Vessel > 0] = 1

        Img = custom_image_loaded[idx]
        Vessel = custom_masks_loaded[idx]        
    else:
        idx=np.random.randint(0,len(ListImages)) # Pick random image   
        Img=cv2.imread(os.path.join(TrainFolder, "Image",ListImages[idx]))  
        Filled =  cv2.imread(os.path.join(TrainFolder,   "Semantic/16_Filled", ListImages[idx].replace("jpg","png")),0)    
        Vessel =  cv2.imread(os.path.join(TrainFolder, "Semantic/1_Vessel", ListImages[idx].replace("jpg","png")),0)

    if show:
        print("Image path: ", str(os.path.join(TrainFolder, "Image",ListImages[idx])))
        # show the image, provide window name first
        cv2.imshow('Img', Img)
        cv2.waitKey(0)
        # cv2.imshow('Filled', Filled)
        # cv2.waitKey(0)
        cv2.imshow('Vessel', Vessel)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
        # print("Filled: ", Filled)
        print("Vessel num: ", np.count_nonzero(Vessel), "/", Vessel.shape, ", range of values: [", np.min(Vessel), ", ", np.max(Filled), "]")
        print("numNan: ", sum(~np.isnan(Vessel)))
        print("Vessel: ", Vessel)

    AnnMap = np.zeros(Img.shape[0:2],np.float32)
    
    if Vessel is not None:  
        AnnMap[ Vessel == 1 ] = 1    
    if Filled is not None:  
        AnnMap[ Filled  == 1 ] = 2

    Img=transformImg(Img)
    AnnMap=transformAnn(AnnMap)

    return Img,AnnMap, Vessel

Number of files in custom directory:  69


In [35]:
# Print test images form dataset
Img,AnnMap, Vessel = ReadRandomImage(show=True)

Image path:  Data/LabPicsV1/Simple/Train/Image/NurdeRage_Make Chlorotoluene (mixture of isomers) 1st step in making Pyrimethamine-screenshot (1).jpg
Vessel num:  81089 / (480, 854) , range of values: [ 0 ,  254 ]
numNan:  [480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480
 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480
 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480
 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480
 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480
 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480
 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480
 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480
 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480
 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480 480
 480 480 480 480 480 480 480 480 480 480 480 480

In [36]:
#--------------Load batch of images-----------------------------------------------------
def LoadBatch(): # Load batch of images
    images = torch.zeros([batchSize,3,height,width])
    ann = torch.zeros([batchSize, height, width])

    for i in range(batchSize):
        images[i],ann[i],_=ReadRandomImage()
    
    return images, ann

In [37]:
#--------------Load and set net and optimizer-------------------------------------
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
Net = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True) # Load net
Net.classifier[4] = torch.nn.Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1)) # Change final layer to 3 classes
Net=Net.to(device)
optimizer=torch.optim.Adam(params=Net.parameters(),lr=Learning_Rate) # Create adam optimizer

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


In [38]:
#----------------Train--------------------------------------------------------------------------
for itr in range(25001): # Training loop
   images,ann=LoadBatch() # Load taining batch
   images=torch.autograd.Variable(images,requires_grad=False).to(device) # Load image
   ann = torch.autograd.Variable(ann, requires_grad=False).to(device) # Load annotation
   Pred=Net(images)['out'] # make prediction
   Net.zero_grad()
   criterion = torch.nn.CrossEntropyLoss() # Set loss function
   Loss=criterion(Pred,ann.long()) # Calculate cross entropy loss
   Loss.backward() # Backpropogate loss
   optimizer.step() # Apply gradient descent change to weight
   seg = torch.argmax(Pred[0], 0).cpu().detach().numpy()  # Get  prediction classes
   print(itr,") Loss=",Loss.data.cpu().numpy())
   
   if itr % 1000 == 0: #Save model weight once every 60k steps permenant file
        print("Saving Model" +str(itr) + ".torch")
        torch.save(Net.state_dict(),  "models/custom_dataset_p_0.25_" + str(itr) + ".torch")

0 ) Loss= 1.0641925
Saving Model0.torch
1 ) Loss= 1.063836
2 ) Loss= 1.1081104
3 ) Loss= 1.123803
4 ) Loss= 1.070643
5 ) Loss= 1.0465094
6 ) Loss= 1.1295844
7 ) Loss= 0.9797787
8 ) Loss= 1.0345088
9 ) Loss= 1.136242
10 ) Loss= 1.0213658
11 ) Loss= 0.92505217
12 ) Loss= 1.078062
13 ) Loss= 1.0253958
14 ) Loss= 1.0435727
15 ) Loss= 1.0635463
16 ) Loss= 1.0249176
17 ) Loss= 0.9800817
18 ) Loss= 0.9944918
19 ) Loss= 0.99445367
20 ) Loss= 0.9344411
21 ) Loss= 0.9990286
22 ) Loss= 0.98906815
23 ) Loss= 1.0092206
24 ) Loss= 0.92696667
25 ) Loss= 0.91998005
26 ) Loss= 0.9849922
27 ) Loss= 0.93664545
28 ) Loss= 0.9297361
29 ) Loss= 0.92482203
30 ) Loss= 0.9136913
31 ) Loss= 0.94877505
32 ) Loss= 0.9470961
33 ) Loss= 0.9528908
34 ) Loss= 0.87010455
35 ) Loss= 0.91237706
36 ) Loss= 0.8769731
37 ) Loss= 0.9729829
38 ) Loss= 0.93276745
39 ) Loss= 0.94826555
40 ) Loss= 0.8501016
41 ) Loss= 0.9673283
42 ) Loss= 0.84300977
43 ) Loss= 0.8974139
44 ) Loss= 0.85204285
45 ) Loss= 0.9220344
46 ) Loss= 1.00