# Torch Segmentation Model with Lab Dataset

Based off of:

https://towardsdatascience.com/train-neural-net-for-semantic-segmentation-with-pytorch-in-50-lines-of-code-830c71a6544f
https://github.com/sagieppel/Train-Semantic-Segmentation-Net-with-Pytorch-In-50-Lines-Of-Code 

MIT License - use available for commercial use

In [1]:
import os
import numpy as np
import cv2
import torchvision.models.segmentation
import torch
import torchvision.transforms as tf
import random

Learning_Rate=1e-5
width=height=800 # image width and height
batchSize=3

In [2]:
TrainFolder="Data/LabPicsV1/Simple/Train/"
ListImages=os.listdir(os.path.join(TrainFolder, "Image"))

In [3]:
transformImg=tf.Compose([tf.ToPILImage(),tf.Resize((height,width)), tf.ToTensor(),tf.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

transformAnn=tf.Compose([tf.ToPILImage(),tf.Resize((height,width)), tf.ToTensor()])

In [18]:
probability = 1
second_train_folder = "../../Skin_Anatomical_Image_Dataset/simple_image_config"

cust_img_path = os.path.join(second_train_folder, "images")
cust_mask_path = os.path.join(second_train_folder, "masks")

custom_imgs = os.listdir(cust_img_path)
custom_masks = os.listdir(cust_mask_path)

# Get length of directory
len_custom_masks = len(custom_masks)
print("Number of files in custom directory: ", len_custom_masks)

def ReadRandomImage(show=False):

    # Probabilistically read in an image that is from our custom dataset
    if random.random() < probability:
        idx=np.random.randint(0,len_custom_masks)

        # print("Idx: ", idx, " filename: ", custom_imgs[idx], ", ", custom_masks[idx])
        
        Img=cv2.imread(os.path.join(cust_img_path,custom_imgs[idx]))
        Vessel =  cv2.imread(os.path.join(cust_mask_path, custom_masks[idx])).max(axis=2) # perhaps switch custom_masks[idx] with custom_imgs[idx] to match image
        Vessel[Vessel == None] = 1
        Vessel.astype(int)
        AnnMap = np.zeros(Img.shape[0:2],np.float32)
        Filled = None
        # print("Vessel Shape: ", Vessel.shape)
    else:
        idx=np.random.randint(0,len(ListImages)) # Pick random image   
        Img=cv2.imread(os.path.join(TrainFolder, "Image",ListImages[idx]))  
        Filled =  cv2.imread(os.path.join(TrainFolder,   "Semantic/16_Filled", ListImages[idx].replace("jpg","png")),0)       
    
        Vessel =  cv2.imread(os.path.join(TrainFolder, "Semantic/1_Vessel", ListImages[idx].replace("jpg","png")),0)
        # print("Vessel Shape2: ", Vessel.shape)
        AnnMap = np.zeros(Img.shape[0:2],np.float32) # Segmentation map

    if show:
        print("Image path: ", str(os.path.join(TrainFolder, "Image",ListImages[idx])))
        # show the image, provide window name first
        cv2.imshow('Img', Img)
        cv2.waitKey(0)
        # cv2.imshow('Filled', Filled)
        # cv2.waitKey(0)
        cv2.imshow('Vessel', Vessel)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
        # print("Filled: ", Filled)
        print("Vessel num: ", np.count_nonzero(Vessel), "/", Vessel.shape, ", range of values: [", np.min(Vessel), ", ", np.max(Filled), "]")
        print("Vessel: ", Vessel)
        print("AnnMap: ", AnnMap)
    
    if Vessel is not None:  
        AnnMap[ Vessel == 1 ] = 1    
    if Filled is not None:  
        AnnMap[ Filled  == 1 ] = 2

    Img=transformImg(Img)
    AnnMap=transformAnn(AnnMap)

    return Img,AnnMap, Vessel

Number of files in custom directory:  69


In [19]:
# Print test images form dataset
Img,AnnMap, Vessel = ReadRandomImage(show=True)

Image path:  Data/LabPicsV1/Simple/Train/Image/IMG_20190302_220740.jpg
Filled num:  4186416 / (3024, 4032) , range of values: [ 0 ,  None ]
Vessel:  [[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
AnnMap:  [[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]


In [6]:
#--------------Load batch of images-----------------------------------------------------
def LoadBatch(): # Load batch of images
    images = torch.zeros([batchSize,3,height,width])
    ann = torch.zeros([batchSize, height, width])

    for i in range(batchSize):
        images[i],ann[i],_=ReadRandomImage()
    
    return images, ann

In [7]:
#--------------Load and set net and optimizer-------------------------------------
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
Net = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True) # Load net
Net.classifier[4] = torch.nn.Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1)) # Change final layer to 3 classes
Net=Net.to(device)
optimizer=torch.optim.Adam(params=Net.parameters(),lr=Learning_Rate) # Create adam optimizer

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


In [8]:
#----------------Train--------------------------------------------------------------------------
for itr in range(2001): # Training loop
   images,ann=LoadBatch() # Load taining batch
   images=torch.autograd.Variable(images,requires_grad=False).to(device) # Load image
   ann = torch.autograd.Variable(ann, requires_grad=False).to(device) # Load annotation
   Pred=Net(images)['out'] # make prediction
   Net.zero_grad()
   criterion = torch.nn.CrossEntropyLoss() # Set loss function
   Loss=criterion(Pred,ann.long()) # Calculate cross entropy loss
   Loss.backward() # Backpropogate loss
   optimizer.step() # Apply gradient descent change to weight
   seg = torch.argmax(Pred[0], 0).cpu().detach().numpy()  # Get  prediction classes
   print(itr,") Loss=",Loss.data.cpu().numpy())
   
   if itr % 1000 == 0: #Save model weight once every 60k steps permenant file
        print("Saving Model" +str(itr) + ".torch")
        torch.save(Net.state_dict(),  "custom_dataset_p_1_" + str(itr) + ".torch")

0 ) Loss= 1.1937019
Saving Model0.torch
1 ) Loss= 1.1891718
2 ) Loss= 1.1922604


KeyboardInterrupt: 