In [1]:
import segmentation_models_pytorch as smp
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights, lraspp_mobilenet_v3_large, deeplabv3_resnet101
import label_test_script
from label_test_script import visualize, reverse_one_hot, colour_code_segmentation

#Hyperparamters
ENCODER = 'resnet101'
ENCODER_WEIGHTS = 'imagenet' #pretrained weighting
#CLASSES = ["background", "skin", "nose", "right_eye", "left_eye", "right_brow", "left_brow", "right_ear", "left_ear", "mouth_interior", "top_lip", "bottom_lip", "neck", "hair", "beard", "clothing", "glasses", "headwear", "facewear"]
ACTIVATION = "sigmoid" # softmax2d for multiclass segmentation
num_classes = 11


preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [3]:
import dlib
import glob

datFile = "/home/nathan/Documents/final_project/shape_predictor_5_face_landmarks.dat"
face_cascade = cv2.CascadeClassifier('haarcascade_frontalface_alt2.xml')
detector = dlib.get_frontal_face_detector()
predictor = dlib.shape_predictor(datFile)  

SIZEX = 128
SIZEY = 128

def shape_to_normal(shape):
  shape_normal = []
  for i in range(0, 5):
      shape_normal.append((i, (shape.part(i).x, shape.part(i).y)))
  return shape_normal

def get_eyes_nose_dlib(shape):
    nose = shape[4][1]
    left_eye_x = int(shape[3][1][0] + shape[2][1][0]) // 2
    left_eye_y = int(shape[3][1][1] + shape[2][1][1]) // 2
    right_eyes_x = int(shape[1][1][0] + shape[0][1][0]) // 2
    right_eyes_y = int(shape[1][1][1] + shape[0][1][1]) // 2
    return nose, (left_eye_x, left_eye_y), (right_eyes_x, right_eyes_y)

def distance(a, b):
    return np.sqrt((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2)

def cosine_formula(length_line1, length_line2, length_line3):
  cos_a = -(length_line3 ** 2 - length_line2 ** 2 - length_line1 ** 2) / (2 * length_line2 * length_line1)
  return cos_a

def rotate_point(origin, point, angle):
    ox, oy = origin
    px, py = point

    qx = ox + np.cos(angle) * (px - ox) - np.sin(angle) * (py - oy)
    qy = oy + np.sin(angle) * (px - ox) + np.cos(angle) * (py - oy)
    return qx, qy

def is_between(point1, point2, point3, extra_point):
    c1 = (point2[0] - point1[0]) * (extra_point[1] - point1[1]) - (point2[1] - point1[1]) * (extra_point[0] - point1[0])
    c2 = (point3[0] - point2[0]) * (extra_point[1] - point2[1]) - (point3[1] - point2[1]) * (extra_point[0] - point2[0])
    c3 = (point1[0] - point3[0]) * (extra_point[1] - point3[1]) - (point1[1] - point3[1]) * (extra_point[0] - point3[0])
    if (c1 < 0 and c2 < 0 and c3 < 0) or (c1 > 0 and c2 > 0 and c3 > 0):
        return True
    else:
        return False

def align(img):
    gray = cv2.cvtColor(img,cv2.COLOR_RGB2GRAY)
    #face alignment
    rects = detector(gray, 0)
    if len(rects) > 0:
        for rect in rects:
            x = rect.left()
            y = rect.top()
            w = rect.right()
            h = rect.bottom()
            shape = predictor(gray, rect)
    else:
      return 0

    shape = shape_to_normal(shape)
    nose, left_eye, right_eye = get_eyes_nose_dlib(shape)

    center_of_forehead = ((left_eye[0] + right_eye[0]) // 2, (left_eye[1] + right_eye[1]) // 2)

    center_pred = (int((x + w) / 2), int((y + y) / 2))

    length_line1 = distance(center_of_forehead, nose)
    length_line2 = distance(center_pred, nose)
    length_line3 = distance(center_pred, center_of_forehead)

    cos_a = cosine_formula(length_line1, length_line2, length_line3)
    angle = np.arccos(cos_a)

    rotated_point = rotate_point(nose, center_of_forehead, angle)
    rotated_point = (int(rotated_point[0]), int(rotated_point[1]))
    if is_between(nose, center_of_forehead, center_pred, rotated_point):
        angle = np.degrees(-angle)
    else:
        angle = np.degrees(angle)
    
    #gray = Image.fromarray(gray)
    #gray = np.array(gray.rotate(angle))

    return angle

def crop_rotate(img):
    """ HAAR CASCADE CLASSIFIER AND ROATER"""
    #face_cascade = cv2.CascadeClassifier('haarcascade_frontalface_alt2.xml')

    crops = []
    original = 0
    skips = []
    rejected = []
    angles = []

    #datFile = "/home/nathan/Documents/final_project/shape_predictor_5_face_landmarks.dat"

    # convert to gray
    gray = cv2.cvtColor(img,cv2.COLOR_RGB2GRAY)

    face = None
    # Detect faces
    try:
      face_cascade=cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_alt2.xml")
      limit = int((gray.shape[0]/2))
      faces = face_cascade.detectMultiScale(gray, 1.05, 2, minSize=[limit,0])
      face = sorted(faces,key=lambda f:f[2]*f[3])[-1]
    except:
      #print("FAILED USING FRONTAL FACE ALT 2")
      try:
        face_cascade=cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_profileface.xml")
        limit = int((gray.shape[0]/2))
        faces = face_cascade.detectMultiScale(gray, 1.05, 2, minSize=[limit,0])
        face = sorted(faces,key=lambda f:f[2]*f[3])[-1]
      except:
        try:
          face_cascade=cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_default.xml")
          limit_1 = int((gray.shape[0]/2.5))
          limit_2 = int((gray.shape[1]/1.7))
          faces = face_cascade.detectMultiScale(gray, 1.05, 3, minSize=[limit_1,0], maxSize=[50000,limit_2]) 
          face = sorted(faces,key=lambda f:f[2]*f[3])[-1]  
        except:
          #print("FAILED USING FRONTAL FACE ALT 2")
          try:
            face_cascade=cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_default.xml")
            limit_1 = int((gray.shape[0]/2.5))
            limit_2 = int((gray.shape[1]/2.5))
            faces = face_cascade.detectMultiScale(gray, minSize=[0,limit_2])
            face = sorted(faces,key=lambda f:f[2]*f[3])[-1]
          except:
            #print("FAILED USING FRONTAL FACE ALT 2")
            try:
              face_cascade=cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_alt.xml")
              limit_1 = int((gray.shape[0]/2))
              limit_2 = int((gray.shape[1]/3))
              faces = face_cascade.detectMultiScale(gray, 1.03, 1, minSize=[limit_1,0])
              face = sorted(faces,key=lambda f:f[2]*f[3])[-1]
            except:
              #print("FAILED USING PROFILE FACE ALT")
              pass

    if face is not None:
      x, y, w, h = face
      crops.append(face)
      face = img[y:y + h, x:x + w]
      face = cv2.resize(face, (SIZEY, SIZEX))
      #cv2_imshow(face)

      angle = align(img)
      face = Image.fromarray(face)
      face_rotated = np.array(face.rotate(angle))
      angles.append(angle)

    return (face_rotated)

In [4]:
#import torchvision.transforms as T
#import torchvision.transforms.functional as F
import albumentations as albu
import random
import scipy
import torch
import os
import cv2
import numpy as np
#from google.colab.patches import cv2_imshow
from matplotlib import pyplot as plt
import re

rgb_vals = [0,1,2,3,4,5,6,7,8,9,10,11]
#rgb_vals = [0,1,2,3,4,5,6,7,8,9,10]


def to_tensor(x, **kwargs):
    x_t = x.transpose(2, 0, 1).astype('float32')
    #print("XTSHAPE", x_t.shape)
    return x_t


def get_preprocessing(preprocessing_fn):
    _transform = [
        albu.Lambda(image=preprocessing_fn),
        albu.Lambda(image=to_tensor, mask=to_tensor),
    ]
    return albu.Compose(_transform)

def get_training_augmentation():
    train_transform = [

        #albu.HorizontalFlip(p=0.5),
        albu.Rotate((-18,18)),
        albu.PadIfNeeded(min_height=SIZEY, min_width=SIZEY, always_apply=True, border_mode=0),
        #albu.RandomCrop(height=320, width=320, always_apply=True),
        albu.Perspective(p=0.5),
        albu.GaussNoise(p=0.2),
        albu.OneOf([albu.CLAHE(p=1),albu.RandomBrightness(p=1),albu.RandomGamma(p=1),],p=0.9,),
        albu.OneOf([albu.Sharpen(p=1),albu.Blur(blur_limit=3, p=1),],p=0.9,),albu.OneOf([albu.RandomContrast(p=1),albu.HueSaturationValue(p=1),],p=0.9,),
    ]

    return albu.Compose(train_transform)


def transformation_augs():
    train_transform = [
        #albu.HorizontalFlip(p=0.5),

    ]
    return albu.Compose(train_transform)


def get_validation_augmentation():
    """Add paddings to make image shape divisible by 32"""
    test_transform = [
        albu.PadIfNeeded(SIZEY, SIZEY)
    ]
    return albu.Compose(test_transform)

class MyDataSet(torch.utils.data.Dataset):

  #CLASSES =  ["background","facial_skin","left_brow","right_brow","left_eye","right_eye", "nose","upper_lip","inner_mouth","lower_lip","hair"]

  def __init__(self, images_dir, masks_dir, coords_dir, preprocessing=None, classes=None,augmentation=None, mode="train", use_landmarks=True):
    super(MyDataSet, self).__init__()
    
    # store the augmented tensors tensors
    #self._x, self._y = x,y
    self.preprocessing = preprocessing
    self.augmentation = augmentation

    self.image_ids = [os.path.join(images_dir, f) for f in os.listdir(images_dir) if f.endswith('.jpg')]

    self.use_landmarks = use_landmarks

    if mode == "val":
      self.masks_fps = [os.path.join(images_dir, image_id) for image_id in self.image_ids]
      self.images_fps = [os.path.join(masks_dir, mask_id.replace("jpg","png")) for mask_id in self.image_ids]
      self.coords_fps = [os.path.join(coords_dir, coords_id.replace(".jpg", "_landmark.txt")) for coords_id in self.image_ids]

  def __len__(self):
    # a DataSet must know it size
    return len(self.images_fps)

  def __getitem__(self, i, put_back=False):

    #print(self.masks_fps[i], self.images_fps[i])

    image = cv2.imread(self.masks_fps[i])
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    mask = cv2.imread(self.images_fps[i],0)

    # crop to center face
    if self.use_landmarks == True:
      smallest_x = 99999
      smallest_y = 99999
      biggest_x = -99999
      biggest_y = -99999
      with open(self.coords_fps[i], 'rb') as f:

        contents = str(f.read()).split("\\n")
        #print(contents)
        contents = [[int(float(single.replace("\\r", "").
                              replace("'", "").replace("b", "").replace("\\x1a", "")
                              )) for single in pair.split(" ")] for pair in contents[2:-2]]
        #print(contents)

        #contents = np.array(contents)
        for pair in contents:
          #plt.scatter((pair[0]), (pair[1]), color="red")

          if (pair[0]) < smallest_x:
            smallest_x = (pair[0])
          
          if (pair[0]) > biggest_x:
            biggest_x = (pair[0])

          if (pair[1]) < smallest_y:
            smallest_y = (pair[1])
          
          if (pair[1]) > biggest_y:
            biggest_y = (pair[1])

        #save original for repositioning
        crop_coords = (int(smallest_y), int(biggest_y), int(smallest_x), int(biggest_x))
        original_image = image
        original_mask = mask

        #crop to points of ineterst
        image = image[smallest_y: biggest_y, smallest_x:biggest_x]
        mask = mask[smallest_y: biggest_y, smallest_x:biggest_x]

      try:
        mask = cv2.resize(mask, (SIZEY, SIZEY))
        image = cv2.resize(image, (SIZEY, SIZEY))
        mask = np.expand_dims(mask,2)
      except:
        print(self.masks_fps[i])
        print(self.images_fps[i])
        plt.imshow(mask)
        plt.imshow(image)
    
    else:
      # use haar classifier
      crop_and_rotate()

    # smooth mask
    # mask = mask = scipy.ndimage.median_filter(mask, 5)

    # apply augmentations
    if self.augmentation:
        sample = self.augmentation(image=image, mask=mask)
        image, mask = sample['image'], sample['mask']
    
    # apply preprocessing
    if self.preprocessing:
        sample = self.preprocessing(image=image, mask=mask)
        image, mask = sample['image'], sample['mask']

    #onehot
    one_hot_Y = torch.nn.functional.one_hot(torch.tensor(mask).to(torch.int64), num_classes).permute(0,3,1,2).float().squeeze(0)

    if put_back == True:
      return (image, one_hot_Y, original_image, original_mask, crop_coords)

    else:
      return (image, one_hot_Y)

  def get_og(self, i):
    return self.__getitem__(i, put_back=True)


val_img_path = "/home/nathan/Documents/final_project/datasets/ibugmask_release/test"
val_mask_path = val_img_path
val_coord_path= val_img_path

val_ds = MyDataSet(val_mask_path,
                   val_img_path,
                   val_coord_path,
                   preprocessing=get_preprocessing(preprocessing_fn),
                   augmentation=get_validation_augmentation(), 
                   mode="val")

image_vis, gt_mask = val_ds[20]

In [5]:
from matplotlib import pyplot as plt

print(len(val_ds))

for x in range(0):

  image_vis, gt_mask = val_ds[x]
  print(x, ":", image_vis.shape, gt_mask.shape)

  gt_mask = colour_code_segmentation(reverse_one_hot(torch.tensor(gt_mask)), rgb_vals)

  visualize(
      original_image = image_vis[0,::],
      ground_truth_mask = gt_mask
  )

#np.unique(image_vis)

154


## Initialise Models

In [6]:
from segmentation_models_pytorch import utils
DEVICE = torch.device("cuda")

## Train Model

### Function to view train model precitions

In [7]:
from sklearn.metrics import jaccard_score, f1_score
import tensorflow as tf
from torchmetrics.classification import F1Score, BinaryF1Score, MulticlassF1Score, JaccardIndex
from torchmetrics import Dice
from sklearn.preprocessing import MultiLabelBinarizer
from torchvision.utils import save_image
from PIL import Image
from skimage.transform import resize
from scipy.ndimage.filters import gaussian_filter

def average(lst):
    return sum(lst) / len(lst)

def view_predictions(model, ds, numm_classes, amount=-1, visualise=False):

    ious = []
    f1s = []

    if amount == -1:
      amount = len(ds)

    #predict
    for idx in range(amount):

        image, gt_mask, og_image, og_mask, coords = ds.get_og(idx)
        image = image
        image_vis = image
        image_vis = np.transpose(image_vis,(1,2,0))
        x_tensor = torch.tensor(image).to(DEVICE).unsqueeze(0)
        pred_mask = model(x_tensor)
        pred_mask = pred_mask.detach().squeeze().cpu().numpy()
        pred_mask = colour_code_segmentation(reverse_one_hot(torch.tensor(pred_mask)), rgb_vals)
        gt_mask = colour_code_segmentation(reverse_one_hot(torch.tensor(gt_mask)), rgb_vals)

        #plt.imshow(og_image,interpolation='none')
        #plt.imshow(gt_mask, alpha=0.2,interpolation='none') # interpolation='none'
        #plt.show()
        
        #get IoU score
        m = tf.keras.metrics.MeanIoU(num_classes=numm_classes)
        m.update_state(gt_mask, pred_mask)
        iou = m.result().numpy()
        #print("MEAN IoU:" , iou)
        ious.append(iou)

        metric = MulticlassF1Score(num_classes=numm_classes, average=None, labels=np.unique(pred_mask) ,validate_args=True)
        #metric = JaccardIndex(task="multiclass",num_classes=numm_classes, average=None, labels=np.unique(pred_mask) ,validate_args=True)
        f1 = metric(torch.tensor(gt_mask), torch.tensor(pred_mask))
        if len(f1) == num_classes:
          f1[f1 <0.1] = 0.1
          #print(f1)
          f1s.append(np.array(f1))

        if visualise:
          visualize(
              img_crop = image.T,
              gt_crop= gt_mask,
              pred_crop = pred_mask,
          )

          og_image = og_image
          upscale_coords = og_image[coords[0]: coords[1], coords[2]:coords[3]].shape
          pred_mask = Image.fromarray(pred_mask.astype(np.uint8)).resize(upscale_coords[:-1][::-1])
          pred_mask = cv2.cvtColor(np.array(pred_mask),cv2.COLOR_GRAY2RGB)
          pred_mask =Image.fromarray(pred_mask)

          [xs,ys]=pred_mask.size  #width*height
          colour_dict = {0:(0,0,0), 1:(39,65,135), 2:(70,136,154), 3:(52,158,136), 4:(37,157,97), 5:(23,180,23), 
                          6:(22,180,23), 7:(113,203,58), 8:(219,213,68), 9:(230,109,11), 10:(255,56,10), 11:(0,0,0), 12:(0,0,0)}
                          
          # Examine every pixel in im
          for x in range(0,xs):
            for y in range(0,ys):
              #get the RGB color of the pixel
              [r,g,b]=pred_mask.getpixel((x, y))
              value = colour_dict[r]
              pred_mask.putpixel((x, y), value)
          
          print(np.array(pred_mask).shape, og_image.shape, upscale_coords)

          pred_mask = np.array(pred_mask)
          for x, rows in enumerate(pred_mask):
            for y, columns in enumerate(rows):
              r = pred_mask[x,y][0]
              if r != 0:

                og_image[x+int(coords[0]),y+int(coords[2])] = pred_mask[x,y]
              else:
                pass

          plt.imshow(og_image,interpolation='none')
          #plt.imshow(gt_mask, alpha=0.2,interpolation='none') # interpolation='none'
          plt.show()

        print(idx)

    
    fs1_numpy = np.array(f1s)
    av_f1s = np.nanmean(fs1_numpy, axis=0)
    #av_f1s = fs1_numpy.mean(axis=0)
    av_f1s_av = av_f1s.mean(axis=0)

    print ("Dataset MIoU = ", average(ious))
    print ("Dataset F1 = ", av_f1s)
    print ("Dataset F1 av = ", av_f1s_av)

2023-02-16 15:57:26.592547: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-02-16 15:57:26.953923: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-02-16 15:57:28.060574: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /home/nathan/miniconda3/lib/python3.9/site-packages/cv2/../../lib64::/home/nathan/miniconda3/lib/:/home/nathan/miniconda3/lib/:/home/nathan/miniconda3/lib/
2023-02-16 15:57:28.060706: W tensorflow/stream_executor

In [8]:
from torchvision.utils import save_image
from PIL import Image



def save_predictions(model, ds):

  xs = []
  ys = []

  #predict
  for idx in range(len(ds)):

      # Pop image from DS
      image, gt_mask = ds[idx]
      image_vis = image
      image_vis = np.transpose(image_vis,(1,2,0))
      
      # Reshape
      x_tensor = torch.tensor(image).to(DEVICE).unsqueeze(0)
      # Predict test image
      pred_mask = model(x_tensor)
      # Reshape
      pred_mask = pred_mask.detach().squeeze().cpu()

      print("saving", idx, "/", len(ds))

      # Save Predictions for use in Label Adapter
      xs.append(pred_mask)

      ys.append(gt_mask)
      
  return xs, ys

#save_predictions(model, val_ds)

# Predict with Deeplabv3+

In [9]:
model = torch.load("/home/nathan/Documents/final_project/saved_models/helen/deeplab.pth", map_location=DEVICE)
view_predictions(model,val_ds, num_classes, amount=-1, visualise=False)

model = torch.load("/home/nathan/Documents/final_project/saved_models/helen/fcn.pth", map_location=DEVICE)
view_predictions(model,val_ds, num_classes, amount=-1, visualise=False)

model = torch.load("/home/nathan/Documents/final_project/saved_models/helen/unet.pth", map_location=DEVICE)
view_predictions(model,val_ds, num_classes, amount=-1, visualise=False)

model = torch.load("/home/nathan/Documents/final_project/saved_models/helen/mobile.pth", map_location=DEVICE)
view_predictions(model,val_ds, num_classes, amount=-1, visualise=False)

  gt_mask = colour_code_segmentation(reverse_one_hot(torch.tensor(gt_mask)), rgb_vals)
2023-02-16 15:57:33.766146: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-02-16 15:57:33.773399: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-02-16 15:57:33.773659: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-02-16 15:57:33.774231: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To 

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
Dataset MIoU =  0.4504468680976273
Dataset F1 =  [0.7040327  0.8971815  0.41992015 0.43955752 0.47745544 0.47638896
 0.8513073  0.5610244  0.51092225 0.6255943  0.40041447]
Dataset F1 av =  0.5785272
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100


### Load and view model predictions

In [10]:

"""
print("MIXED:")
print("deep")
model = torch.load("/home/nathan/Documents/final_project/saved_models/mixed/deeplab.pth", map_location=DEVICE)
view_predictions(model,val_ds, num_classes)

print("FCN")
model = torch.load("/home/nathan/Documents/final_project/saved_models/mixed/fcn.pth", map_location=DEVICE)
view_predictions(model,val_ds, num_classes)

print("mobile")
model = torch.load("/home/nathan/Documents/final_project/saved_models/mixed/mobile.pth", map_location=DEVICE)
view_predictions(model,val_ds, num_classes)

print("unet")
model = torch.load("/home/nathan/Documents/final_project/saved_models/mixed/unet.pth", map_location=DEVICE)
view_predictions(model,val_ds, num_classes)



print("HELEN:")
print("deep")
model = torch.load("/home/nathan/Documents/final_project/saved_models/helen/deeplab.pth", map_location=DEVICE)
view_predictions(model,val_ds, num_classes)

print("FCN")
model = torch.load("/home/nathan/Documents/final_project/saved_models/helen/fcn.pth", map_location=DEVICE)
view_predictions(model,val_ds, num_classes)

print("mobile")
model = torch.load("/home/nathan/Documents/final_project/saved_models/helen/mobile.pth", map_location=DEVICE)
view_predictions(model,val_ds, num_classes)

print("unet")
model = torch.load("/home/nathan/Documents/final_project/saved_models/helen/unet.pth", map_location=DEVICE)
view_predictions(model,val_ds, num_classes)



print("LAPA:")
print("deep")
model = torch.load("/home/nathan/Documents/final_project/saved_models/lapa/deeplab.pth", map_location=DEVICE)
view_predictions(model,val_ds, num_classes)

print("FCN")
model = torch.load("/home/nathan/Documents/final_project/saved_models/lapa/fcn.pth", map_location=DEVICE)
view_predictions(model,val_ds, num_classes)

print("mobile")
model = torch.load("/home/nathan/Documents/final_project/saved_models/lapa/mobile.pth", map_location=DEVICE)
view_predictions(model,val_ds, num_classes)

print("unet")
model = torch.load("/home/nathan/Documents/final_project/saved_models/lapa/unet.pth", map_location=DEVICE)
view_predictions(model,val_ds, num_classes)


print("SYNTH:")
print("deep")
model = torch.load("/home/nathan/Documents/final_project/saved_models/synth/deeplab.pth", map_location=DEVICE)
view_predictions(model,val_ds, 12)

print("FCN")
model = torch.load("/home/nathan/Documents/final_project/saved_models/synth/fcn.pth", map_location=DEVICE)
view_predictions(model,val_ds, 12)
"""
#print("mobile")
#model = torch.load("/home/nathan/Documents/final_project/saved_models/synth/mobile.pth", map_location=DEVICE)
#view_predictions(model,val_ds, 12)

#print("unet")
#model = torch.load("/home/nathan/Documents/final_project/saved_models/synth/unet.pth", map_location=DEVICE)
#view_predictions(model,val_ds, 12)


'\nprint("MIXED:")\nprint("deep")\nmodel = torch.load("/home/nathan/Documents/final_project/saved_models/mixed/deeplab.pth", map_location=DEVICE)\nview_predictions(model,val_ds, num_classes)\n\nprint("FCN")\nmodel = torch.load("/home/nathan/Documents/final_project/saved_models/mixed/fcn.pth", map_location=DEVICE)\nview_predictions(model,val_ds, num_classes)\n\nprint("mobile")\nmodel = torch.load("/home/nathan/Documents/final_project/saved_models/mixed/mobile.pth", map_location=DEVICE)\nview_predictions(model,val_ds, num_classes)\n\nprint("unet")\nmodel = torch.load("/home/nathan/Documents/final_project/saved_models/mixed/unet.pth", map_location=DEVICE)\nview_predictions(model,val_ds, num_classes)\n\n\n\nprint("HELEN:")\nprint("deep")\nmodel = torch.load("/home/nathan/Documents/final_project/saved_models/helen/deeplab.pth", map_location=DEVICE)\nview_predictions(model,val_ds, num_classes)\n\nprint("FCN")\nmodel = torch.load("/home/nathan/Documents/final_project/saved_models/helen/fcn.pt

# label adapter

In [11]:
xs, ys = save_predictions(model, train_ds)

NameError: name 'train_ds' is not defined

In [None]:
len(xs)

In [None]:
for x in range(5):
  visualize(
      original_image = colour_code_segmentation(reverse_one_hot(xs[x]), rgb_vals),
      ground_truth_mask = colour_code_segmentation(reverse_one_hot(ys[x]), rgb_vals),
      #ground_truth_mask = colour_code_segmentation(reverse_one_hot(ys[x]), rgb_vals)
  )


In [None]:
X2 = torch.stack(xs)
Y2 = torch.stack(ys)
print(X2.shape)
print(Y2.shape)

In [None]:
Y2 = Y2.float()
X2 = X2.float()
Y2.dtype

In [None]:
from sklearn.model_selection import train_test_split
x_train2, x_test2, y_train2, y_test2 = train_test_split(X2, Y2, test_size=0.2, shuffle=True)

print(x_train2.shape)
print(y_train2.shape)


#X = X.numpy().reindex(np.random.permutation(X.index))
#one_hot_Y = one_hot_Y.numpy().reindex(np.random.permutation(one_hot_Y.index))

In [None]:
for x in range(10):
  visualize(
      original_image =  reverse_one_hot(x_train2[x]),
      ground_truth_mask = reverse_one_hot(y_train2[x]),
      #one_hot_encoded_mask = reverse_one_hot(y_test[x])
  )


## create our datasets

In [None]:
import torchvision.transforms as T
import torchvision.transforms.functional as F
import random

class MyDataSet(torch.utils.data.Dataset):
  def __init__(self, x, y):
    super(MyDataSet, self).__init__()
    
    # store the augmented tensors tensors
    self._x, self._y = x,y

  def __len__(self):
    # a DataSet must know it size
    return self._x.shape[0]

  def __getitem__(self, index):
    x = torch.tensor(np.expand_dims(colour_code_segmentation(reverse_one_hot(torch.tensor(self._x[index, :])), rgb_vals),0).astype(float)).to(device="cuda", dtype=torch.float)
    y = self._y[index, :]
    # print("GETTING ITEM")
    return x, y

train_ds2 = MyDataSet(x_train2, y_train2)
val_ds2 = MyDataSet(x_test2, y_test2)


### Data Loader

In [None]:
from torch.utils.data import DataLoader

# Get train and val data loaders
train_loader2 = DataLoader(train_ds2, batch_size=64, shuffle=True)
valid_loader2 = DataLoader(val_ds2, batch_size=20, shuffle=False)

In [None]:
print(len(train_loader2))

In [None]:
label_adapter = smp.Unet(
    in_channels=1,
    encoder_name=ENCODER, 
    encoder_weights="imagenet", 
    classes=num_classes, 
    activation=ACTIVATION,
    #encoder_depth = 18,
    #decoder_channels = 18,
    decoder_use_batchnorm = True,
    #aux_params=aux_params
)
#preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [None]:
train_model(label_adapter, train_loader2, valid_loader2, "/content/drive/MyDrive/FRESH/label_adapter.pth", 100)

### Load and view model predictions

In [None]:
from sklearn.metrics import jaccard_score, f1_score
import tensorflow as tf
from torchmetrics.classification import F1Score, BinaryF1Score, MulticlassF1Score, JaccardIndex
from torchmetrics import Dice
from sklearn.preprocessing import MultiLabelBinarizer
from torchvision.utils import save_image
from PIL import Image

def average(lst):
    return sum(lst) / len(lst)

def view_label_predictions(model, ds, numm_classes ):

    ious = []
    f1s = []

    #predict
    for idx in range(len(ds)):

        image, gt_mask = ds[idx]
        image = image.cpu()
        image_vis = image.cpu()
        image_vis = np.transpose(image_vis,(1,2,0))
        print("vis:",image_vis.shape)
        print("im:",image.shape)
        
        x_tensor = torch.tensor(image).to(DEVICE).unsqueeze(0)
        #print("X_TENSOR:", x_tensor, x_tensor.shape)
        # Predict test image
        pred_mask = model(x_tensor)
        print("predraw", pred_mask.shape)
        print("gtraw", gt_mask.shape)
        pred_mask = pred_mask.detach().squeeze().cpu().numpy()
        # Convert pred_mask from `CHW` format to `HWC` format
        print(pred_mask.shape)
        # Get prediction channel corresponding to face
        pred_mask = colour_code_segmentation(reverse_one_hot(torch.tensor(pred_mask)), rgb_vals)
        print(pred_mask.shape)
        
        # Convert gt_mask from `CHW` format to `HWC` format
        print(gt_mask.shape)
        gt_mask = colour_code_segmentation(reverse_one_hot(torch.tensor(gt_mask)), rgb_vals)
        
        #get IoU score
        m = tf.keras.metrics.MeanIoU(num_classes=numm_classes)
        m.update_state(gt_mask, pred_mask)
        iou = m.result().numpy()
        #print("MEAN IoU:" , iou)
        ious.append(iou)

        #gt_mask = scipy.signal.medfilt(gt_mask, 9)

        #Get f1
        #m = MultiLabelBinarizer().fit(gt_mask)
        #f1 = f1_score(m.transform(gt_mask), m.transform(pred_mask), average=None)
        #if len(f1) == num_classes:
        #  f1s.append(f1)

        metric = MulticlassF1Score(num_classes=numm_classes, average=None, labels=np.unique(pred_mask) ,validate_args=True)
        #metric = JaccardIndex(task="multiclass",num_classes=numm_classes, average=None, labels=np.unique(pred_mask) ,validate_args=True)
        f1 = metric(torch.tensor(pred_mask), torch.tensor(gt_mask))
        if len(f1) == num_classes:
          f1[f1 <0.1] = np.nan
          print(f1)
          f1s.append(np.array(f1))

        try:
          if idx < 20:
            visualize(
                original_image = image[0,::],
                ground_truth_mask = gt_mask,
                predicted_mask = pred_mask,
            )
        except:
          if idx < 20:
            visualize(
                original_image = image.cuda()[0,::],
                ground_truth_mask = gt_mask.cuda(),
                predicted_mask = pred_mask.cuda(),
            )

    
    fs1_numpy = np.array(f1s)
    av_f1s = np.nanmean(fs1_numpy, axis=0)
    #av_f1s = fs1_numpy.mean(axis=0)
    av_f1s_av = av_f1s.mean(axis=0)

    print ("Dataset MIoU = ", average(ious))
    print ("Dataset F1 = ", av_f1s)
    print ("Dataset F1 av = ", av_f1s_av)

In [None]:
model = torch.load("/content/drive/MyDrive/FRESH/label_adapter.pth", map_location=DEVICE)

view_label_predictions(model,val_ds2, num_classes)