<a href="https://colab.research.google.com/github/hmaharaja/AI_stuff/blob/Steel-defect-classification/Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **1. Importing Kaggle Data**

---



In [0]:
#Link to competition: https://www.kaggle.com/c/severstal-steel-defect-detection/overview
#Upload the Kaggle JSON file (download it from the drive to your computer first)
from google.colab import files
#files.upload() returns a dict of {filename: file bytes}. The trailing semicolon stops the notebook from
#echoing that dict as the cell output, which would save the Kaggle API key inside the notebook.
files.upload();

In [0]:
#Install the Kaggle CLI and put the uploaded kaggle.json credentials where the CLI looks for them
#Confirm that kaggle.json was uploaded to the current working directory
!ls -lha kaggle.json
!pip install -q kaggle
#Copy the credentials into ~/.kaggle
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
#Restrict the credentials file to owner read/write only
!chmod 600 ~/.kaggle/kaggle.json

In [0]:
#Download the competition archives and extract the required files
!kaggle competitions download -c severstal-steel-defect-detection -q
#Extract train.csv into the current directory
!unzip -q train.csv.zip
#Extract the training images into the train_images/ folder
!unzip -q train_images.zip -d train_images

In [0]:
# Include all import statements required over here
import csv
import numpy as np
np.set_printoptions(suppress=True)
import matplotlib.pyplot as plt
import seaborn as sn
from PIL import Image, ImageDraw
import math
import os
import shutil
import random
import cv2
from scipy import ndimage, misc
import itertools
from skimage import io, transform

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import datasets, transforms, utils, models

**List of global variables used in the file and their purpose (update as needed)**

`trainingDataPath` : path to the train.csv file

`data` : array extracted from train.csv file

`filenameAndClassIndex` : dictionary with all `ImageId_ClassId` strings from train.csv file as keys (e.g. '002cc93b.jpg_1') and row position in `data` as values (e.g. 1).  Use it to index `data` by filename (e.g. to get encodedPixels for a filename and known class) 

`numInClassCounter` : An array of counters for the number of images with defects in each class. `numInClassCounter[0]` is the number of images w/ no defects, `numInClassCounter[1]` is the number of images w/ class 1 defects... 

`filenamesByClass` : a 2D list where the first dimension corresponds to the classID (0: No defects, ..., 4: Class 4) and the second dimension is a list of filenames with **any** defects in that class

`uniqueFilenamesByClass` : a 2D list where the first dimension corresponds to the classID (0: No defects, ..., 4: Class 4) and the second dimension is a list of filenames that **only** have defects in that one class

`numClasses = 5`

# **2. Import the dataset**

---



In [0]:
#Load the dataset and list of filenames

def removeDuplicates(lst):
    #Return a new list holding the first occurrence of each value in lst, in the original order
    alreadySeen = set()
    uniqueValues = []
    for value in lst:
        if value not in alreadySeen:
            alreadySeen.add(value)
            uniqueValues.append(value)
    return uniqueValues

#Look up the run-length encoded pixels string for a filename and a classID from 1 to 4.
#Reads the module-level `filenameAndClassIndex` (ImageId_ClassId -> row number) and `data` (rows of train.csv).
def getEncodedPixels(filename, classID):
    rowKey = filename + "_" + str(classID)
    rowNumber = filenameAndClassIndex[rowKey]
    csvRow = data[rowNumber]
    return csvRow[1]

#Load train.csv as a list of rows (row 0 is the header)
trainingDataPath = "/content/train.csv"
#newline='' lets the csv module handle line endings itself, as the csv documentation requires
with open(trainingDataPath, 'r', newline='') as file:
    data = list(csv.reader(file, delimiter=","))

filenameAndClassIndex = {} #dictionary to index data array based on first column's value (ImageId_ClassId)
filenames = []

#Start at row 1 to skip the header row
for i in range(1, len(data)):
    filenameAndClassIndex[data[i][0]] = i
    #The first column is "<filename>_<classID>"; keep only the filename part
    filenames.append(data[i][0].split("_")[0])
filenames = removeDuplicates(filenames) # remove duplicates from the list of filenames

# **3. Data Exploration**

---



**3.1 - Descriptive Statistics**


3.2.1 - Example images without any processing

In [0]:
#Look at a few images to try and identify the defects
#Plot a few images from each class without any masks
pltFilenames = [] #filenames chosen here; this list is reused whenever the same images are plotted again

numImagesPerClass = 3

fig, axs = plt.subplots(numClasses, numImagesPerClass, figsize=(numClasses*4,numImagesPerClass*4))

for i in range(numClasses):
    for j in range(numImagesPerClass): 
        #choose a random image with only one defect for each class 
        random.seed(i+j+21)
        #random.randint includes its upper bound, so the largest valid index is len - 1
        #(using len itself could raise an IndexError)
        candidates = uniqueFilenamesByClass[i]
        pltFilenames.append(candidates[random.randint(0, len(candidates) - 1)])
        
        filename = pltFilenames[-1]

        #Plot the image as is
        img = RGBToGrey(imageToRGBArray(getImage(filename)))
        axs[i, j].imshow(img, cmap = "gray")
        axs[i, j].set_title("Class " + str(i) + " | Picture " + str(j+1) + " | Filename: " + str(filename))
        axs[i, j].set_xticks([])
        axs[i, j].set_yticks([])
fig.tight_layout()

3.2.2 - Example images with ground-truth pixel defect masks applied in red

In [0]:
#Plot the same images from above, but with masks to see what the defects are
fig, axs = plt.subplots(numClasses, numImagesPerClass, figsize=(numClasses*4,numImagesPerClass*4))
filenameIter = iter(pltFilenames)
for i in range(numClasses):
    for j in range(numImagesPerClass): 
        filename = next(filenameIter)
        #Class 0 means "no defects", so there is nothing to look up for it
        encodedPixels = getEncodedPixels(filename, i) if i != 0 else ""
        if encodedPixels != "":
            #Plot the image with the defect mask applied in red
            img = imageToRGBArray(applyMask(getImage(filename), encodedPixels, (255, 0, 0, 30)))
        else:
            #No mask to apply (class 0 or an empty encodedPixels value): plot the image unchanged,
            #so a previous iteration's image is never drawn again by mistake
            img = imageToRGBArray(getImage(filename))
        
        axs[i, j].imshow(img)
        axs[i, j].set_title("Class " + str(i) + " | Picture " + str(j+1) + " | Filename: " + str(filename))
        axs[i, j].set_xticks([])
        axs[i, j].set_yticks([])
fig.tight_layout()

3.2.3 - Example images processed with histogram equalization

In [0]:
#Plot the same images from above without masks, but this time with histogram equalization
fig, axs = plt.subplots(numClasses, numImagesPerClass, figsize=(numClasses*4,numImagesPerClass*4))
filenameIter = iter(pltFilenames)
#Walk the grid row by row (class i, picture j), the same order the filenames were chosen in
for i, j in itertools.product(range(numClasses), range(numImagesPerClass)):
    filename = next(filenameIter)
    #Equalize the greyscale image; only the first element of the returned value is plotted
    greyImg = RGBToGrey(imageToRGBArray(getImage(filename)))
    img = imgHistogramEqualization(greyImg)[0]
    ax = axs[i, j]
    ax.imshow(img, cmap = "gray")
    ax.set_title(" | ".join(["Class " + str(i), "Picture " + str(j+1), "Filename: " + str(filename)]))
    ax.set_xticks([])
    ax.set_yticks([])
fig.tight_layout()

3.2.4 - Example images processed with thresholding

In [0]:
#Plot the same images from above without masks, but this time with thresholding
fig, axs = plt.subplots(numClasses, numImagesPerClass, figsize=(numClasses*4,numImagesPerClass*4))
filenameIter = iter(pltFilenames)
#Walk the grid row by row (class i, picture j), the same order the filenames were chosen in
for i, j in itertools.product(range(numClasses), range(numImagesPerClass)):
    filename = next(filenameIter)
    #Threshold the RGB image with a threshold argument of 0.3
    rgbImg = imageToRGBArray(getImage(filename))
    img = thresholdImg(rgbImg, 0.3)
    ax = axs[i, j]
    ax.imshow(img)
    ax.set_title(" | ".join(["Class " + str(i), "Picture " + str(j+1), "Filename: " + str(filename)]))
    ax.set_xticks([])
    ax.set_yticks([])
fig.tight_layout()

3.2.5 - Example images processed with Sobel filters (horizontal/vertical edge detection)

refer to http://www.cs.cmu.edu/~16385/s17/Slides/4.0_Image_Gradients_and_Gradient_Filtering.pdf for more information on this filter

In [0]:
#Plot the same images from above without masks, but this time with a Sobel filter
fig, axs = plt.subplots(numClasses, numImagesPerClass, figsize=(numClasses*4,numImagesPerClass*4))
filenameIter = iter(pltFilenames)
#Walk the grid row by row (class i, picture j), the same order the filenames were chosen in
for i, j in itertools.product(range(numClasses), range(numImagesPerClass)):
    filename = next(filenameIter)
    #Apply the Sobel filter to the greyscale version of the image
    greyImg = RGBToGrey(imageToRGBArray(getImage(filename)))
    img = sobelImg(greyImg)
    ax = axs[i, j]
    ax.imshow(img, cmap = "gray")
    ax.set_title(" | ".join(["Class " + str(i), "Picture " + str(j+1), "Filename: " + str(filename)]))
    ax.set_xticks([])
    ax.set_yticks([])
fig.tight_layout()

3.2.6 - Example images processed with a Laplacian derivative filter and gaussian smoothing

refer to http://www.cs.cmu.edu/~16385/s17/Slides/4.0_Image_Gradients_and_Gradient_Filtering.pdf for more information on this filter

In [0]:
#Plot the same images from above without masks, but this time with a Laplace filter with gaussian smoothing
fig, axs = plt.subplots(numClasses, numImagesPerClass, figsize=(numClasses*4,numImagesPerClass*4))
filenameIter = iter(pltFilenames)
#Walk the grid row by row (class i, picture j), the same order the filenames were chosen in
for i, j in itertools.product(range(numClasses), range(numImagesPerClass)):
    filename = next(filenameIter)
    #Apply the gaussian-smoothed Laplace filter to the greyscale version of the image
    greyImg = RGBToGrey(imageToRGBArray(getImage(filename)))
    img = gaussianLaplaceImg(greyImg)
    ax = axs[i, j]
    ax.imshow(img, cmap = "gray")
    ax.set_title(" | ".join(["Class " + str(i), "Picture " + str(j+1), "Filename: " + str(filename)]))
    ax.set_xticks([])
    ax.set_yticks([])
fig.tight_layout()

**Conclusions:**

*From pie chart*
- Of the images with defects, class 3 is the most common, followed by classes 1, 4, and 2
- A majority of the images have no defects

*From comparison matrix*
- The most common combination of multi-class defects in an image is a combination of class 3 and class 4. This is followed by class 1 and class 3, class 1 and class 2, class 2 and class 3, and class 2 and class 4.

*From pixel intensity histogram*
- Class 2 defects are generally distributed over lower intensities than any other class
- Classes 1 and 3 have similar distributions, meaning it may potentially be difficult to distinguish between them
- For images without defects, the histogram peak (i.e. the intensity shared by the largest number of pixels) is at a higher intensity than the peak for any other class
- Class 4 pictures do not seem to peak at a certain intensity; the pixels have a relatively equal distribution of intensities 

*From visualizing defect images*
- Class 1 defects seem to be tiny transverse cracks/indentations (i.e. the crack length extends into the page, but is being viewed from above), or potentially pitting, given their relatively small size compared to the other defect classes 
- Class 2 defects seem the most difficult to detect, they are barely noticeable vertical scratches on the surface. This class may be difficult to detect due to this observation, as well as the fact that the dataset has the least amount of images in this class
- Class 3 defects seem to be scratches/indentations, typically vertical, going a shallow depth into the steel
- Class 4 defects look like larger and deeper indentations/pits as compared to class 3

*From trying potential pre-processing filters*

Histogram equalization
- Histogram equalization seems to work well for highlighting contrast around defects in classes 3 and 4 (which are typically larger than the defects in the other classes). The technique does not work well for class 2 or class 1 where the defects are smaller in size, and generally doesn't work well for any image where there is a lot of black space and the steel does not cover the whole image. 
- This suggests that images should be pre-processed to check for empty space and then have it removed. A CNN would be able to handle inputs of different sizes, but padding would need to be added around the output so all images have consistent sizing when going through any linear, fully-connected layers.

Thresholding
- This technique was generally unhelpful. Although it identified some contrast around the larger class 4 defects, it did not do it as well as histogram equalization

# **4. Split Filenames For Training, Validation and Testing Sets**

In [0]:
#Split into training, validation, and testing set

#First split up the filenames randomly. Use a 70-15-15 split within each class, append to the set, and then remove duplicates
trainImageNames, valImageNames, testImageNames = [], [], []
for i in range(0, numClasses):
    #Shuffle a copy so that filenamesByClass is not reordered in place; re-running this cell then
    #reproduces exactly the same split
    filesForClass = list(filenamesByClass[i])
    random.seed(1)
    random.shuffle(filesForClass)

    trainEnd = int(0.7*len(filesForClass))
    valEnd = int(0.85*len(filesForClass))
    trainImageNames.append(filesForClass[:trainEnd])
    valImageNames.append(filesForClass[trainEnd:valEnd])
    testImageNames.append(filesForClass[valEnd:])

#Flatten 2D lists to 1D
trainImageNames = list(itertools.chain(*trainImageNames))
valImageNames = list(itertools.chain(*valImageNames))
testImageNames = list(itertools.chain(*testImageNames))

#Remove duplicate names within each set
trainImageNames = removeDuplicates(trainImageNames)
valImageNames = removeDuplicates(valImageNames)
testImageNames = removeDuplicates(testImageNames)

#An image with defects in several classes is listed under each of those classes, so the per-class splits
#above can put the same image into more than one set. Keep every image only in the first set it reached
#(training, then validation, then testing) so that no evaluation image is also trained on.
#The training list itself is left unchanged.
trainNameSet = set(trainImageNames)
valImageNames = [name for name in valImageNames if name not in trainNameSet]
trainOrValNameSet = trainNameSet.union(valImageNames)
testImageNames = [name for name in testImageNames if name not in trainOrValNameSet]

# **5. Load Images and Segmentation Masks**

In [0]:
#Functions to convert encodedPixels string into a segmentation mask and vice versa
#
#Both functions use the competition's run-length format: space-separated "start length" pairs, where
#pixels are numbered starting from 1, going top to bottom and then left to right (column-major order).
def rle_decode(encodedPixels, picWidth, picHeight): 
    '''
    Convert a run-length encoded string into a binary segmentation mask.

    encodedPixels: string of "start length" pairs with 1-based, column-major pixel numbers ("" = no defect)
    picWidth, picHeight: size of the picture in pixels
    Returns a float array of shape (picHeight, picWidth) with 1 on encoded pixels and 0 elsewhere
    Raises ValueError if the string has an odd number of values or addresses pixels outside the picture
    '''
    numPixels = picWidth * picHeight
    mask = np.zeros((picHeight, picWidth))

    tokens = encodedPixels.split()
    if not tokens:
        return mask
    if len(tokens) % 2 != 0:
        raise ValueError("encodedPixels must contain an even number of values (start/length pairs)")

    values = [int(token) for token in tokens]
    flatMask = np.zeros(numPixels)
    for start, length in zip(values[0::2], values[1::2]):
        firstIndex = start - 1 #pixel numbers in the string start at 1, array indices start at 0
        if firstIndex < 0 or length < 0 or firstIndex + length > numPixels:
            raise ValueError("run ({}, {}) lies outside a {}x{} picture".format(start, length, picWidth, picHeight))
        flatMask[firstIndex:firstIndex + length] = 1

    #Pixels are numbered down each column first, so refill the 2D mask in column-major ('F') order
    return np.ascontiguousarray(flatMask.reshape((picHeight, picWidth), order='F'))

def rle_encode(img):
    '''
    img: numpy array, 1 - mask, 0 - background
    Returns run length as string formated: space-separated "start length" pairs with 1-based,
    column-major pixel numbers. An all-zero mask gives an empty string.
    '''
    pixels = np.asarray(img).flatten('F')
    non_zero_indices = np.flatnonzero(pixels)
    if non_zero_indices.size == 0:
        return ''

    #A new run begins wherever two consecutive non-zero indices are not neighbours
    run_breaks = np.flatnonzero(np.diff(non_zero_indices) != 1) + 1
    first_positions = np.concatenate(([0], run_breaks))
    last_positions = np.concatenate((run_breaks - 1, [non_zero_indices.size - 1]))

    run_starts = non_zero_indices[first_positions]
    run_lengths = non_zero_indices[last_positions] - run_starts + 1

    #+1 converts the 0-based array index into the 1-based pixel number used by the format
    return ' '.join('{} {}'.format(int(start) + 1, int(length)) for start, length in zip(run_starts, run_lengths))

In [0]:
#Custom dataset class to load the images, segmentation masks, and apply consistent transformations to them
class SteelImageDataset(Dataset):
    """Steel image dataset.

    Each item is an (image, masks) pair of tensors. `masks` stacks one mask per defect class,
    in class order (index 0 is class 1, index 1 is class 2, ...). The image and all of its masks
    always receive the same resize, flips and rotation.
    """

    def __init__(self, filenames_list, rle_data, resize_width = int(1600/2), resize_height = int(256/2), apply_transforms = True, transform_apply_prob = 0.25, max_angle = 4):
        """
        Args:
            filenames_list (list): list of filenames for the class.
            rle_data (list): original train.csv file loaded as a 2D list
            resize_width (int): width in pixels that images and masks are resized to
            resize_height (int): height in pixels that images and masks are resized to
            apply_transforms (bool): whether the random flips and rotation are applied
            transform_apply_prob (float): probability of applying each random transform
            max_angle (float): upper bound of the random rotation angle, passed to ndimage.rotate
        """
        self.filenames_list = filenames_list
        self.rle_data = rle_data #stored only; the masks are looked up through getEncodedPixels
        self.resize_width = resize_width
        self.resize_height = resize_height
        self.apply_transforms = apply_transforms #choose whether flips and rotations are done
        self.transform_apply_prob = transform_apply_prob
        self.max_angle = max_angle

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.filenames_list)

    def transform(self, image, masks):
        """Resize and (optionally) randomly flip/rotate an image together with its masks.

        Args:
            image: image array; presumably (height, width, channels) - TODO confirm against imageToRGBArray
            masks: sequence of 2D mask arrays, one per defect class
        Returns:
            (image, masks) as torch tensors
        """
        # Resize. cv2.resize takes the target size as (width, height).
        target_size = (self.resize_width, self.resize_height)
        image = cv2.resize(image, target_size)
        # Masks are labels, so resize them with nearest-neighbour sampling: the default (bilinear)
        # interpolation would blend 0 and 1 into fractional values along every defect edge.
        new_masks = np.array([cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST) for mask in masks])

        if self.apply_transforms:
            # Random horizontal flipping, applied with probability transform_apply_prob
            if np.random.random() < self.transform_apply_prob:
                image = np.flip(image, 1)
                new_masks = np.array([np.flip(mask, 1) for mask in new_masks])

            # Random vertical flipping, applied with probability transform_apply_prob
            if np.random.random() < self.transform_apply_prob:
                image = np.flip(image, 0)
                new_masks = np.array([np.flip(mask, 0) for mask in new_masks])

            # Random rotation, applied with probability transform_apply_prob
            if np.random.random() < self.transform_apply_prob:
                angle = random.uniform(0, self.max_angle)
                image = ndimage.rotate(image, angle, reshape = False)
                # order=0 (nearest neighbour) keeps the masks binary; the default cubic spline
                # produces fractional and even negative values
                new_masks = np.array([ndimage.rotate(mask, angle, reshape = False, order = 0) for mask in new_masks])

        # Transform to tensor. copy() is needed because np.flip returns a view with negative strides,
        # which torch.from_numpy cannot wrap.
        image = torch.from_numpy(image.copy())
        mask_tensors = torch.from_numpy(new_masks.copy())

        return image, mask_tensors

    def __getitem__(self, idx):
        """Load the image at position idx, decode its per-class masks and transform both."""
        filename = self.filenames_list[idx]
        img_name = "/content/train_images/" + filename
        image = imageToRGBArray(getImage(img_name)) #Resnet needs a 3 channel img, so don't convert to greyscale
        #image = RGBToGrey(imageToRGBArray(getImage(img_name)))

        # One mask per defect class (1 .. numClasses-1), decoded at the original image size
        masks = []
        for i in range(1, numClasses):
            masks.append(rle_decode(getEncodedPixels(filename, i), image.shape[1], image.shape[0]))
        masks = np.asarray(masks)

        image, transformedMasks = self.transform(image, masks)

        return image, transformedMasks 

In [0]:
#Build the training dataset with the default resize and augmentation settings
train_dataset = SteelImageDataset(trainImageNames, data)
#The validation and testing datasets are not created yet
#val_dataset = SteelImageDataset(valImageNames, data)
#test_dataset = SteelImageDataset(testImageNames, data)


In [0]:
# a known image with defects in class 1 and 3
# NOTE(review): index 4137 depends on the order of trainImageNames - re-check it if the split changes
image, masks = train_dataset[4137] 
plt.imshow(image.numpy(), cmap = 'gray')

In [0]:
#verify that class 3 map is transformed correctly
#(masks are stacked in class order starting at class 1, so class 3 is at index 2)
plt.imshow(masks.numpy()[2], cmap = 'gray') 

In [0]:
#verify that class 1 map is transformed correctly
#(masks are stacked in class order starting at class 1, so class 1 is at index 0)
plt.imshow(masks.numpy()[0], cmap = 'gray') 

# **6. Model Architectures**

In [0]:
#Baseline ANN model

In [0]:
#U-Net Architecture (to use for final model) (taken from https://github.com/usuyama/pytorch-unet)

def convrelu(in_channels, out_channels, kernel, padding):
    """Build a 2D convolution followed by an in-place ReLU, wrapped in an nn.Sequential."""
    convolution = nn.Conv2d(in_channels, out_channels, kernel, padding=padding)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(convolution, activation)

class ResNetUNet(nn.Module):
    """U-Net style segmentation network that uses a pretrained ResNet-18 as its encoder.

    The encoder stages are slices of the ResNet-18 child modules. The decoder repeatedly
    upsamples by a factor of 2 and concatenates the matching encoder output (after a 1x1
    convolution) along the channel dimension. A final 1x1 convolution produces `n_class`
    output channels; no activation is applied to the output.
    """

    def __init__(self, n_class):
        """
        Args:
            n_class (int): number of output channels of the final convolution
        """
        super().__init__()

        # Pretrained backbone; its child modules are sliced into the encoder stages below
        self.base_model = models.resnet18(pretrained=True)
        self.base_layers = list(self.base_model.children())

        # Encoder stages, each paired with a 1x1 convolution applied to its skip connection
        self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2)
        self.layer0_1x1 = convrelu(64, 64, 1, 0)
        self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4)
        self.layer1_1x1 = convrelu(64, 64, 1, 0)
        self.layer2 = self.base_layers[5]  # size=(N, 128, x.H/8, x.W/8)
        self.layer2_1x1 = convrelu(128, 128, 1, 0)
        self.layer3 = self.base_layers[6]  # size=(N, 256, x.H/16, x.W/16)
        self.layer3_1x1 = convrelu(256, 256, 1, 0)
        self.layer4 = self.base_layers[7]  # size=(N, 512, x.H/32, x.W/32)
        self.layer4_1x1 = convrelu(512, 512, 1, 0)

        # A single upsampling module (no parameters) shared by every decoder stage
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Decoder convolutions; input channels = skip connection channels + upsampled channels
        self.conv_up3 = convrelu(256 + 512, 512, 3, 1)
        self.conv_up2 = convrelu(128 + 512, 256, 3, 1)
        self.conv_up1 = convrelu(64 + 256, 256, 3, 1)
        self.conv_up0 = convrelu(64 + 256, 128, 3, 1)

        # Full-resolution branch applied directly to the 3-channel input
        self.conv_original_size0 = convrelu(3, 64, 3, 1)
        self.conv_original_size1 = convrelu(64, 64, 3, 1)
        self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)

        # Final 1x1 convolution mapping 64 channels to n_class output channels
        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, input):
        """Run the network on a batch of 3-channel images.

        Args:
            input: tensor with 3 channels in dimension 1 (N, 3, H, W)
                   NOTE(review): H and W presumably need to be divisible by 32 so that the
                   upsampled tensors line up with the skip connections - confirm
        Returns:
            tensor with n_class channels in dimension 1
        """
        # Full-resolution features, concatenated back in at the very end
        x_original = self.conv_original_size0(input)
        x_original = self.conv_original_size1(x_original)

        # Encoder: each stage consumes the output of the previous one
        layer0 = self.layer0(input)
        layer1 = self.layer1(layer0)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)

        # Decoder stage 3: upsample the deepest features and merge with layer3
        layer4 = self.layer4_1x1(layer4)
        x = self.upsample(layer4)
        layer3 = self.layer3_1x1(layer3)
        x = torch.cat([x, layer3], dim=1)
        x = self.conv_up3(x)

        # Decoder stage 2: merge with layer2
        x = self.upsample(x)
        layer2 = self.layer2_1x1(layer2)
        x = torch.cat([x, layer2], dim=1)
        x = self.conv_up2(x)

        # Decoder stage 1: merge with layer1
        x = self.upsample(x)
        layer1 = self.layer1_1x1(layer1)
        x = torch.cat([x, layer1], dim=1)
        x = self.conv_up1(x)

        # Decoder stage 0: merge with layer0
        x = self.upsample(x)
        layer0 = self.layer0_1x1(layer0)
        x = torch.cat([x, layer0], dim=1)
        x = self.conv_up0(x)

        # Last upsample, then merge with the full-resolution branch
        x = self.upsample(x)
        x = torch.cat([x, x_original], dim=1)
        x = self.conv_original_size2(x)

        out = self.conv_last(x)

        return out

# **7. Training Functions**

In [0]:
#work in progress, not yet done
def get_accuracy(model, data_loader, useGPU = True):
    """Compute the pixel-wise accuracy of `model` over every batch in `data_loader`.

    Args:
        model: network mapping a float (N, C, H, W) image batch to one logit per mask channel and pixel
        data_loader: iterable of (imgs, masks) batches, with imgs in (N, H, W, C) layout and masks
                     holding one binary mask per channel
        useGPU: run on the GPU when one is available
    Returns:
        fraction (0..1) of mask entries predicted correctly, or 0.0 if the loader is empty.
        A logit above 0 (i.e. a sigmoid probability above 0.5) counts as a predicted defect pixel.
    """
    onGPU = useGPU and torch.cuda.is_available()
    if onGPU:
        model = model.cuda()

    #Evaluate in inference mode, and put the model back afterwards if it was being trained
    wasTraining = model.training
    model.eval()

    correct = 0
    total = 0
    with torch.no_grad(): #no gradients are needed to measure accuracy
        for imgs, masks in data_loader:
            imgs = imgs.permute(0, 3, 1, 2).float() #N, C, H, W format instead of N, H, W, C

            #To Enable GPU Usage
            if onGPU:
                imgs = imgs.cuda()
                masks = masks.cuda()

            output = model(imgs)

            #compare the thresholded prediction with the ground truth, pixel by pixel
            pred = output > 0
            target = masks > 0.5
            correct += pred.eq(target).sum().item()
            total += target.numel()

    if wasTraining:
        model.train()

    return correct / total if total > 0 else 0.0

def plot_training_curve(iters, losses, train_acc, val_acc):
    """Plot the training loss and the training/validation accuracy against the iteration number.

    Args:
        iters: iteration numbers (x axis of both plots)
        losses: training loss recorded at each iteration
        train_acc: training accuracy recorded at each iteration
        val_acc: validation accuracy recorded at each iteration
    Prints the last recorded accuracies, if there are any.
    """
    plt.title("Training Curve")
    plt.plot(iters, losses, label="Train")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend(loc='best')
    plt.show()

    plt.title("Training Curve")
    plt.plot(iters, train_acc, label="Train")
    plt.plot(iters, val_acc, label="Validation")
    plt.xlabel("Iterations")
    plt.ylabel("Accuracy") #this plot shows both the training and the validation accuracy
    plt.legend(loc='best')
    plt.show()

    #Nothing was recorded if training ran for zero iterations
    if train_acc and val_acc:
        print("Final Training Accuracy: {}".format(train_acc[-1]))
        print("Final Validation Accuracy: {}".format(val_acc[-1]))  

def train(model, train_dataset, val_dataset, batch_size=64, num_epochs=1, learning_rate = 0.01, momentum = 0.9, useGPU = True, saveWeights = True, useAdams = False):
    """Train a segmentation model and plot the training curves.

    Args:
        model: network taking float (N, C, H, W) images and returning one logit per mask channel and pixel
        train_dataset: dataset of (image, masks) pairs used for the parameter updates
        val_dataset: dataset of (image, masks) pairs used for the validation accuracy
        batch_size: number of images per batch
        num_epochs: number of passes over train_dataset
        learning_rate: optimizer learning rate
        momentum: SGD momentum (ignored when useAdams is True)
        useGPU: train on the GPU when one is available
        saveWeights: save a checkpoint after every epoch and the training curves as CSV files
        useAdams: use the Adam optimizer instead of SGD
    """
    #Put data in data loaders. The training data is shuffled because the training filenames are
    #grouped class by class; without shuffling every batch would come from a single class.
    train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, 
                                            num_workers=0, shuffle=True)
    val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, 
                                        num_workers=0, shuffle=False)

    #Each output channel is an independent binary mask (there is no background channel), so score every
    #channel with binary cross-entropy on the logits rather than a softmax cross-entropy over channels
    criterion = nn.BCEWithLogitsLoss()
    if useAdams:
        optimizer = optim.Adam(model.parameters(), learning_rate)
    else:
        optimizer = optim.SGD(model.parameters(), learning_rate, momentum)
    
    iters, losses, train_acc, val_acc = [], [], [], []

    onGPU = useGPU and torch.cuda.is_available()
    if onGPU:
        model = model.cuda()
        print("Training on GPU")

    #Not every model defines a `name` attribute; fall back to the class name for the checkpoint files
    modelName = getattr(model, "name", type(model).__name__)
    model_path = None
   
    # training
    n = 0 # the number of iterations
    for epoch in range(num_epochs):
        for i, batch in enumerate(train_data_loader):
            imgs, masks = batch
            
            imgs = imgs.permute(0, 3, 1, 2) #N, C, H, W format instead of N, H, W, C
            imgs = imgs.float()
            masks = masks.float() #the loss needs targets with the same dtype as the model output
            
            #To Enable GPU Usage
            if onGPU:
                imgs = imgs.cuda()
                masks = masks.cuda()

            model.train()                 # make sure the model is in training mode for the update
            out = model(imgs)             # forward pass

            loss = criterion(out, masks) # compute the total loss
            loss.backward()               # backward pass (compute parameter updates)
            optimizer.step()              # make the updates for each parameter
            optimizer.zero_grad()         # a clean up step for PyTorch

            # save the current training information
            iters.append(n)
            losses.append(float(loss))    # the criterion already averages the loss over the batch
            # NOTE: both accuracies are computed over the whole datasets after every batch, which is slow
            train_acc.append(get_accuracy(model, train_data_loader, useGPU)) # compute training accuracy 
            val_acc.append(get_accuracy(model, val_data_loader, useGPU))  # compute validation accuracy
            print("Iteration: ", str(n), "| Train Loss: ", losses[n], "| Train Accuracy: ", train_acc[n], "| Validation Accuracy: ", val_acc[n])
            
            n += 1
            
        # Save the current model (checkpoint) to a file
        if saveWeights:
            model_path = "model_{0}_bs{1}_lr{2}_epoch{3}".format(modelName,
                                                       batch_size,
                                                       str(learning_rate).replace('.', '-'),
                                                       epoch)
            torch.save(model.state_dict(), model_path + ".pth")
            
    # Write the training loss and the train/validation accuracy into CSV files for plotting later
    # (model_path stays None when no epoch was run, in which case there is nothing to write)
    if saveWeights and model_path is not None:
        np.savetxt("{}_train_loss.csv".format(model_path), losses)
        np.savetxt("{}_train_acc.csv".format(model_path), train_acc)
        np.savetxt("{}_val_acc.csv".format(model_path), val_acc)            

    # plotting
    plot_training_curve(iters, losses, train_acc, val_acc)

In [0]:
#Create a model with 4 output channels and train it with batches of 3 images
model = ResNetUNet(4)
#NOTE(review): train_dataset is passed as both the training and the validation set here
train(model, train_dataset, train_dataset, batch_size = 3)

# X. Functions for painting defects 

In [0]:
#Function to load, paint, and save pictures given a list of filenames. Outputs jpg file in folder specified within function

def loadPaintAndSavePictures(filenames, filenameAndClassIndex):
    """Paint the defects of each picture in `filenames` and save the result as a jpg file.

    Args:
        filenames (list): filenames of the pictures to paint
        filenameAndClassIndex (dict): its keys are the ImageId_ClassId strings for which a mask is looked up
    The painted pictures are written to /content/train_images_highlighted_defects as
    "<name> highlighted.jpg". The current working directory is left unchanged.
    """
    #Defect colours by class: 1 - Red, 2 - Green, 3 - Blue, 4 - Yellow. Opacity: 1 = transparent, 255 = opaque
    opacity = 30
    colour = [(255, 0, 0, opacity), (0, 255, 0, opacity), (0, 0, 255, opacity), (255, 255, 0, opacity)]

    #Create the output folder once, before painting anything
    highlightsPath = "/content/train_images_highlighted_defects" 
    os.makedirs(highlightsPath, exist_ok=True)

    for file in filenames:
        #Load the picture
        picture = getImage(file)

        # Paint the picture for each defect type
        for i in range(1,5):
            if file + "_" + str(i) in filenameAndClassIndex:
                encodedPixels = getEncodedPixels(file, i)
                picture = applyMask(picture, encodedPixels, colour[i - 1])

        #Save the picture using its full path instead of changing the working directory, so that
        #relative paths used elsewhere keep pointing to the same place
        picture.save(os.path.join(highlightsPath, file.split(".jpg")[0] + " highlighted.jpg"))

    print("Highlighting completed")


In [0]:
#Paint the first x pictures from the dataset

x  = 4 #number of pictures to paint, taken from the start of `filenames`
loadPaintAndSavePictures(filenames[:x], filenameAndClassIndex)

In [0]:
#Zip the folder of highlighted images into a single archive (-r includes the folder contents, -q is quiet)
!zip -r "/content/train_images_highlighted_defects.zip"  "/content/train_images_highlighted_defects" -q