# Torch Segmentation Model with Lab Dataset

Based off of:

https://towardsdatascience.com/train-neural-net-for-semantic-segmentation-with-pytorch-in-50-lines-of-code-830c71a6544f
https://github.com/sagieppel/Train-Semantic-Segmentation-Net-with-Pytorch-In-50-Lines-Of-Code 

MIT License - use available for commercial use

In [2]:
import os
import numpy as np
import cv2
import torchvision.models.segmentation
import torch
import torchvision.transforms as tf

Learning_Rate=1e-5
width=height=800 # image width and height
batchSize=3

In [3]:
TrainFolder="Data/LabPicsV1/Simple/Train/"
ListImages=os.listdir(os.path.join(TrainFolder, "Image"))

In [4]:
transformImg=tf.Compose([tf.ToPILImage(),tf.Resize((height,width)), tf.ToTensor(),tf.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

transformAnn=tf.Compose([tf.ToPILImage(),tf.Resize((height,width)), tf.ToTensor()])

In [11]:
def ReadRandomImage(show=False):   
    
    idx=np.random.randint(0,len(ListImages)) # Pick random image   
    Img=cv2.imread(os.path.join(TrainFolder, "Image",ListImages[idx]))  
    Filled =  cv2.imread(os.path.join(TrainFolder,   "Semantic/16_Filled", ListImages[idx].replace("jpg","png")),0)       
 
    Vessel =  cv2.imread(os.path.join(TrainFolder, "Semantic/1_Vessel", ListImages[idx].replace("jpg","png")),0) 
    AnnMap = np.zeros(Img.shape[0:2],np.float32) # Segmentation map

    if show:
        print("Image path: ", str(os.path.join(TrainFolder, "Image",ListImages[idx])))
        # show the image, provide window name first
        cv2.imshow('Img', Img)
        cv2.waitKey(0)
        cv2.imshow('Filled', Filled)
        cv2.waitKey(0)
        cv2.imshow('Vessel', Vessel)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
        # cv2.imshow('AnnMap', np.uint8(AnnMap))
        # cv2.waitKey(0)
        # cv2.destroyAllWindows()
        print("Image: ", Img)
        print("Filled: ", Filled)
        print("Vessel: ", Vessel)
        print("AnnMap: ", AnnMap)
    
    if Vessel is not None:  
        AnnMap[ Vessel == 1 ] = 1    
    if Filled is not None:  
        AnnMap[ Filled  == 1 ] = 2
    
    Img=transformImg(Img)
    AnnMap=transformAnn(AnnMap)

    return Img,AnnMap, Vessel

In [12]:
# Print test images form dataset
Img,AnnMap, Vessel = ReadRandomImage(show=True)

Image path:  Data/LabPicsV1/Simple/Train/Image/ChemPlayer_Nickel chloride and boride-screenshot.jpg
Image:  [[[13 12 16]
  [13 12 16]
  [13 12 16]
  ...
  [30 25 26]
  [30 25 26]
  [30 25 26]]

 [[13 12 16]
  [13 12 16]
  [13 12 16]
  ...
  [35 30 31]
  [35 30 31]
  [35 30 31]]

 [[13 12 16]
  [13 12 16]
  [13 12 16]
  ...
  [36 31 32]
  [36 31 32]
  [36 31 32]]

 ...

 [[79 79 91]
  [74 74 86]
  [76 76 88]
  ...
  [46 49 57]
  [49 52 60]
  [53 56 64]]

 [[76 76 88]
  [67 67 79]
  [68 68 80]
  ...
  [46 49 57]
  [49 52 60]
  [51 54 62]]

 [[69 69 81]
  [68 68 80]
  [69 69 81]
  ...
  [42 45 53]
  [46 49 57]
  [46 49 57]]]
Filled:  [[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
Vessel:  [[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
AnnMap:  [[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.

In [5]:
#--------------Load batch of images-----------------------------------------------------
def LoadBatch(): # Load batch of images
    images = torch.zeros([batchSize,3,height,width])
    ann = torch.zeros([batchSize, height, width])

    for i in range(batchSize):
        images[i],ann[i],_=ReadRandomImage()
    
    return images, ann

In [6]:
#--------------Load and set net and optimizer-------------------------------------
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
Net = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True) # Load net
Net.classifier[4] = torch.nn.Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1)) # Change final layer to 3 classes
Net=Net.to(device)
optimizer=torch.optim.Adam(params=Net.parameters(),lr=Learning_Rate) # Create adam optimizer

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "
Downloading: "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth" to /home/carwyn/.cache/torch/hub/checkpoints/deeplabv3_resnet50_coco-cd0a2569.pth
100%|██████████| 161M/161M [00:04<00:00, 37.3MB/s] 


In [7]:
#----------------Train--------------------------------------------------------------------------
for itr in range(10000): # Training loop
   images,ann=LoadBatch() # Load taining batch
   images=torch.autograd.Variable(images,requires_grad=False).to(device) # Load image
   ann = torch.autograd.Variable(ann, requires_grad=False).to(device) # Load annotation
   Pred=Net(images)['out'] # make prediction
   Net.zero_grad()
   criterion = torch.nn.CrossEntropyLoss() # Set loss function
   Loss=criterion(Pred,ann.long()) # Calculate cross entropy loss
   Loss.backward() # Backpropogate loss
   optimizer.step() # Apply gradient descent change to weight
   seg = torch.argmax(Pred[0], 0).cpu().detach().numpy()  # Get  prediction classes
   print(itr,") Loss=",Loss.data.cpu().numpy())
   
   if itr % 1000 == 0: #Save model weight once every 60k steps permenant file
        print("Saving Model" +str(itr) + ".torch")
        torch.save(Net.state_dict(),   str(itr) + ".torch")

0 ) Loss= 1.0658557
Saving Model0.torch
1 ) Loss= 1.0500449
2 ) Loss= 1.036717
3 ) Loss= 0.9943116
4 ) Loss= 1.0568877
5 ) Loss= 1.0175107
6 ) Loss= 1.0378245
7 ) Loss= 1.0149018
8 ) Loss= 1.0815524
9 ) Loss= 0.9888705
10 ) Loss= 1.0671489
11 ) Loss= 0.99889594
12 ) Loss= 1.0209575
13 ) Loss= 0.9810101
14 ) Loss= 0.99607784
15 ) Loss= 0.9659469
16 ) Loss= 0.95963675
17 ) Loss= 0.98452586
18 ) Loss= 0.9917831
19 ) Loss= 0.96990806
20 ) Loss= 0.95592886
21 ) Loss= 1.036369
22 ) Loss= 0.8784146
23 ) Loss= 0.9541694
24 ) Loss= 0.94949543
25 ) Loss= 0.9399611
26 ) Loss= 0.90972096
27 ) Loss= 1.0184376
28 ) Loss= 0.9256802
29 ) Loss= 0.9892294
30 ) Loss= 0.9155993
31 ) Loss= 0.9524102
32 ) Loss= 0.93066067
33 ) Loss= 0.90166485
34 ) Loss= 0.89346594
35 ) Loss= 0.9617312
36 ) Loss= 0.8970508
37 ) Loss= 0.8774571
38 ) Loss= 0.9642111
39 ) Loss= 0.9433593
40 ) Loss= 0.87536484
41 ) Loss= 0.93168795
42 ) Loss= 0.87668705
43 ) Loss= 0.8768035
44 ) Loss= 0.8927781
45 ) Loss= 0.90663785
46 ) Loss= 

KeyboardInterrupt: 