# Faces in the Wild with GCN Semantic Segmentation

This notebook runs the GCN semantic segmentation network on a single image and saves the predicted labels as a color-coded image.  
*Eurecat 2019 - Rafael Redondo*

In [0]:
# Architectural parameters: these must match the ones the checkpoint was trained with.
target_size = 256   # input images are resized to target_size x target_size pixels
num_classes = 7     # number of class labels the network predicts per pixel
num_levels = 3      # passed to GCN(num_classes, num_levels); semantics defined in models.GCN

# I/O. Relative paths are resolved against the working directory at the time they are
# used, i.e. the ELFW Drive folder once the `cd` cell below has run.
checkpoint = "./gcn-epoch_0480.pth"
input_file = "./Bettina_Rheims_0001.jpg"
result_file = "./Bettina_Rheims_0001_results.png"

# Colorize your labeled classes: one RGB color per class index, starting at index 0.
label_colors = [
    (0, 0, 0),
    (0, 255, 0),
    (255, 0, 0),
    (0, 255, 255),
    (0, 0, 255),
    (255, 0, 255),
    (255, 255, 0)]

# Fail early on a misconfigured palette instead of with an IndexError when compositing.
assert len(label_colors) >= num_classes, (
    f"label_colors defines {len(label_colors)} colors but num_classes is {num_classes}")

In [0]:
# Mount Google Drive under /content/drive.
# Colab prints an authorization link; open it and paste the code back when prompted.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/drive


In [0]:
# Go to Drive contents
cd drive/My\ Drive/ELFW

/content/drive/My Drive/ELFW


In [0]:
# Import
import torch
from torch.autograd import Variable
from torchvision.transforms import ToTensor, Normalize
import torch.nn.functional as F
from models import GCN
from PIL import Image
import numpy as np

# Model loading (the ResNet backbone may take a while to download the first time).
# The GCN is wrapped in DataParallel *before* loading, presumably because the checkpoint
# was saved from a DataParallel model ("module."-prefixed keys) — confirm if loading fails.
model = torch.nn.DataParallel(GCN(num_classes, num_levels))
# Load the tensors onto the CPU first so the checkpoint does not have to match the GPU it
# was saved from; the whole model is moved to the GPU right afterwards.
model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
model.cuda()
# Switch layers such as dropout / batch-norm to inference behavior.
# The trailing ';' suppresses the module repr that eval() would otherwise display.
model.eval();

In [0]:
# Pass forward
# Resize the input to the training resolution and normalize it with the ImageNet channel
# mean/std (presumably the statistics used at training time — confirm).
image = Image.open(input_file).convert("RGB")
image = image.resize((target_size, target_size), Image.BILINEAR)
img = ToTensor()(image)
img = Normalize([.485, .456, .406], [.229, .224, .225])(img)
img = img.unsqueeze(0).cuda()  # add a batch dimension of size 1

# Inference only: no_grad() skips building the autograd graph, which saves GPU memory.
# Plain tensors are used because torch.autograd.Variable is a no-op wrapper since PyTorch 0.4.
with torch.no_grad():
    scores = model(img)

# Per-pixel class log-probabilities of the first (and only) image in the batch.
# Assumes scores is (batch, num_classes, H, W) — TODO confirm against models.GCN.
label_probs = F.log_softmax(scores[0], dim=0).cpu().numpy()



In [0]:
# Composite: paint every pixel with the color of its most likely class.
labels = np.argmax(label_probs, axis=0)  # (H, W) map of class indices

# Palette lookup: indexing the (num_colors, 3) color table with the label map yields an
# (H, W, 3) uint8 image in one vectorized step. The output is sized from the prediction
# itself rather than assuming target_size x target_size.
palette = np.asarray(label_colors, dtype=np.uint8)
rgb = palette[labels]

result = Image.fromarray(rgb)
result.save(result_file)
print('Results saved.')

Results saved.
