# Torch Segmentation Model with Lab Dataset

Based off of:

https://towardsdatascience.com/train-neural-net-for-semantic-segmentation-with-pytorch-in-50-lines-of-code-830c71a6544f
https://github.com/sagieppel/Train-Semantic-Segmentation-Net-with-Pytorch-In-50-Lines-Of-Code 

MIT License - use available for commercial use

In [1]:
import os
import numpy as np
import cv2
import torchvision.models.segmentation
import torch
import torchvision.transforms as tf
import random

Learning_Rate=1e-5
width=height=800 # image width and height
batchSize=3

In [2]:
TrainFolder="Data/LabPicsV1/Simple/Train/"
ListImages=os.listdir(os.path.join(TrainFolder, "Image"))

In [3]:
transformImg=tf.Compose([tf.ToPILImage(),tf.Resize((height,width)), tf.ToTensor(),tf.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

transformAnn=tf.Compose([tf.ToPILImage(),tf.Resize((height,width)), tf.ToTensor()])

In [20]:
probability = 1
second_train_folder = "../../Skin_Anatomical_Image_Dataset/simple_image_config"

cust_img_path = os.path.join(second_train_folder, "images")
cust_mask_path = os.path.join(second_train_folder, "masks")

custom_imgs = os.listdir(cust_img_path)
custom_masks = os.listdir(cust_mask_path)

# Get length of directory
len_custom_masks = len(custom_masks)
print("Number of files in custom directory: ", len_custom_masks)

def ReadRandomImage(show=False):

    # Probabilistically read in an image that is from our custom dataset
    if random.random() < probability:
        idx=np.random.randint(0,len_custom_masks)

        # print("Idx: ", idx, " filename: ", custom_imgs[idx], ", ", custom_masks[idx])
        
        Img=cv2.imread(os.path.join(cust_img_path,custom_imgs[idx]))
        Vessel =  cv2.imread(os.path.join(cust_mask_path, custom_imgs[idx].replace("jpeg","png").replace("jpg","png"))).max(axis=2) # perhaps switch custom_masks[idx] with custom_imgs[idx] to match image
        Vessel[Vessel > 0] = 1
        # Vessel[Vessel == None] = 1
        # Vessel.astype(int)

        # print("Vessel: ", Vessel)
        
        Filled = None
        # print("Vessel Shape: ", Vessel.shape)
    else:
        idx=np.random.randint(0,len(ListImages)) # Pick random image   
        Img=cv2.imread(os.path.join(TrainFolder, "Image",ListImages[idx]))  
        Filled =  cv2.imread(os.path.join(TrainFolder,   "Semantic/16_Filled", ListImages[idx].replace("jpg","png")),0)       
    
        Vessel =  cv2.imread(os.path.join(TrainFolder, "Semantic/1_Vessel", ListImages[idx].replace("jpg","png")),0)
        # print("Vessel Shape2: ", Vessel.shape)

    if show:
        print("Image path: ", str(os.path.join(TrainFolder, "Image",ListImages[idx])))
        # show the image, provide window name first
        cv2.imshow('Img', Img)
        cv2.waitKey(0)
        # cv2.imshow('Filled', Filled)
        # cv2.waitKey(0)
        cv2.imshow('Vessel', Vessel)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
        # print("Filled: ", Filled)
        print("Vessel num: ", np.count_nonzero(Vessel), "/", Vessel.shape, ", range of values: [", np.min(Vessel), ", ", np.max(Filled), "]")
        print("numNan: ", sum(~np.isnan(Vessel)))
        print("Vessel: ", Vessel)

    AnnMap = np.zeros(Img.shape[0:2],np.float32)
    
    if Vessel is not None:  
        AnnMap[ Vessel == 1 ] = 1    
    if Filled is not None:  
        AnnMap[ Filled  == 1 ] = 2

    Img=transformImg(Img)
    AnnMap=transformAnn(AnnMap)

    return Img,AnnMap, Vessel

Number of files in custom directory:  69


In [21]:
# Print test images form dataset
Img,AnnMap, Vessel = ReadRandomImage(show=True)

Image path:  Data/LabPicsV1/Simple/Train/Image/IMG_20190103_085653.jpg
Vessel num:  5287598 / (3024, 4032) , range of values: [ 0 ,  None ]
numNan:  [3024 3024 3024 ... 3024 3024 3024]
Vessel:  [[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]


In [22]:
#--------------Load batch of images-----------------------------------------------------
def LoadBatch(): # Load batch of images
    images = torch.zeros([batchSize,3,height,width])
    ann = torch.zeros([batchSize, height, width])

    for i in range(batchSize):
        images[i],ann[i],_=ReadRandomImage()
    
    return images, ann

In [23]:
#--------------Load and set net and optimizer-------------------------------------
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
Net = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True) # Load net
Net.classifier[4] = torch.nn.Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1)) # Change final layer to 3 classes
Net=Net.to(device)
optimizer=torch.optim.Adam(params=Net.parameters(),lr=Learning_Rate) # Create adam optimizer

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


In [24]:
#----------------Train--------------------------------------------------------------------------
for itr in range(2001): # Training loop
   images,ann=LoadBatch() # Load taining batch
   images=torch.autograd.Variable(images,requires_grad=False).to(device) # Load image
   ann = torch.autograd.Variable(ann, requires_grad=False).to(device) # Load annotation
   Pred=Net(images)['out'] # make prediction
   Net.zero_grad()
   criterion = torch.nn.CrossEntropyLoss() # Set loss function
   Loss=criterion(Pred,ann.long()) # Calculate cross entropy loss
   Loss.backward() # Backpropogate loss
   optimizer.step() # Apply gradient descent change to weight
   seg = torch.argmax(Pred[0], 0).cpu().detach().numpy()  # Get  prediction classes
   print(itr,") Loss=",Loss.data.cpu().numpy())
   
   if itr % 1000 == 0: #Save model weight once every 60k steps permenant file
        print("Saving Model" +str(itr) + ".torch")
        torch.save(Net.state_dict(),  "custom_dataset_p_1_" + str(itr) + ".torch")

0 ) Loss= 1.0717262
Saving Model0.torch
1 ) Loss= 1.0926737
2 ) Loss= 1.0795858
3 ) Loss= 1.0208857
4 ) Loss= 1.0434695
5 ) Loss= 1.0357416
6 ) Loss= 0.9973864
7 ) Loss= 1.0180311
8 ) Loss= 0.9528891
9 ) Loss= 0.93231106
10 ) Loss= 0.97271144
11 ) Loss= 0.939484
12 ) Loss= 0.96754736
13 ) Loss= 0.9589764
14 ) Loss= 0.94653165
15 ) Loss= 0.9161028
16 ) Loss= 0.85357624
17 ) Loss= 0.90847313
18 ) Loss= 0.9022595
19 ) Loss= 0.94125146
20 ) Loss= 0.8548535
21 ) Loss= 0.8704541
22 ) Loss= 0.8629475
23 ) Loss= 0.8336007
24 ) Loss= 0.83568627
25 ) Loss= 0.8563431
26 ) Loss= 0.82525605
27 ) Loss= 0.8304503
28 ) Loss= 0.8468362
29 ) Loss= 0.82246107
30 ) Loss= 0.8023564
31 ) Loss= 0.77258694
32 ) Loss= 0.7548555
33 ) Loss= 0.75191295
34 ) Loss= 0.7925494
35 ) Loss= 0.8141155
36 ) Loss= 0.7801252
37 ) Loss= 0.719839
38 ) Loss= 0.772162
39 ) Loss= 0.7426101
40 ) Loss= 0.73222524
41 ) Loss= 0.71724224
42 ) Loss= 0.7016625
43 ) Loss= 0.7023671
44 ) Loss= 0.70575684
45 ) Loss= 0.6849401
46 ) Loss= 0